From e29095146fcca26be76f71404735c8ae4ef8ef85 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 3 Jun 2025 03:20:15 +0000 Subject: [PATCH 01/90] WIP: Output modes --- pydantic_ai_slim/pydantic_ai/__init__.py | 4 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 28 +- pydantic_ai_slim/pydantic_ai/_output.py | 414 ++++++++++++++---- pydantic_ai_slim/pydantic_ai/agent.py | 1 + .../pydantic_ai/models/__init__.py | 33 +- .../pydantic_ai/models/anthropic.py | 2 +- .../pydantic_ai/models/bedrock.py | 2 +- .../pydantic_ai/models/function.py | 4 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 2 +- .../pydantic_ai/models/mistral.py | 8 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 57 ++- pydantic_ai_slim/pydantic_ai/models/test.py | 4 +- pydantic_ai_slim/pydantic_ai/result.py | 25 +- .../test_openai_json_schema_output.yaml | 223 ++++++++++ ...st_openai_json_schema_output_multiple.yaml | 293 +++++++++++++ .../test_openai_manual_json_output.yaml | 211 +++++++++ .../test_openai/test_openai_tool_output.yaml | 227 ++++++++++ tests/models/test_bedrock.py | 2 +- tests/models/test_fallback.py | 6 +- tests/models/test_gemini.py | 60 ++- tests/models/test_instrumented.py | 28 +- tests/models/test_model_request_parameters.py | 8 +- tests/models/test_openai.py | 352 ++++++++++++++- tests/test_agent.py | 15 +- tests/test_direct.py | 2 +- tests/test_logfire.py | 6 +- 28 files changed, 1878 insertions(+), 143 deletions(-) create mode 100644 tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_tool_output.yaml diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 21ef4dec6..af0cb8a15 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .result import ToolOutput +from .result import JSONSchemaOutput, ManualJSONOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -43,6 +43,8 @@ 'RunContext', # result 'ToolOutput', + 'JSONSchemaOutput', + 'ManualJSONOutput', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index bf8397174..f6ec91268 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -264,10 +264,29 @@ async def add_mcp_server_tools(server: MCPServer) -> None: function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] output_schema = ctx.deps.output_schema + model = ctx.deps.model + + # TODO: This is horrible + output_mode = None + output_object = None + output_tools = [] + require_tool_use = False + if output_schema: + output_mode = output_schema.forced_mode or model.default_output_mode + output_object = output_schema.object_schema.definition + output_tools = output_schema.tool_defs() + require_tool_use = output_mode == 'tool' and output_schema.allow_text_output != 'plain' + + supported_modes = model.supported_output_modes + if output_mode not in supported_modes: + raise exceptions.UserError(f"Output mode '{output_mode}' is not among supported modes: {supported_modes}") + return models.ModelRequestParameters( function_tools=function_tool_defs, - allow_text_output=_output.allow_text_output(output_schema), - output_tools=output_schema.tool_defs() if output_schema is not None else [], + output_mode=output_mode, + output_object=output_object, + output_tools=output_tools, + require_tool_use=require_tool_use, ) @@ -536,9 +555,12 @@ async def _handle_text_response( text = '\n\n'.join(texts) try: - if _output.allow_text_output(output_schema): + if output_schema is None or output_schema.allow_text_output == 'plain': # The following cast is safe because we know `str` is an allowed result type result_data = cast(NodeRunEndT, text) + elif output_schema.allow_text_output == 'json': + run_context = build_run_context(ctx) + result_data = await output_schema.process(text, run_context) else: m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 3f2900c58..e9a5030bd 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,8 +1,10 @@ from __future__ import annotations as _annotations import inspect +import json from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field +from textwrap import dedent from typing import Any, Callable, Generic, Literal, Union, cast from pydantic import TypeAdapter, ValidationError @@ -50,6 +52,17 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result' DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' +DEFAULT_MANUAL_JSON_PROMPT = dedent( # TODO: Move to ModelProfile + """ + Always respond with a JSON object matching this description and schema: + + {description} + + {schema} + + Don't include any text or Markdown fencing before or after. + """ +) @dataclass @@ -113,7 +126,7 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): class ToolOutput(Generic[OutputDataT]): """Marker class to use tools for outputs, and customize the tool.""" - output_type: SimpleOutputType[OutputDataT] + output_type: OutputTypeOrFunction[OutputDataT] # TODO: Allow list of types instead of unions? name: str | None description: str | None max_retries: int | None @@ -121,7 +134,7 @@ class ToolOutput(Generic[OutputDataT]): def __init__( self, - type_: SimpleOutputType[OutputDataT], + type_: OutputTypeOrFunction[OutputDataT], *, name: str | None = None, description: str | None = None, @@ -135,20 +148,68 @@ def __init__( self.strict = strict +@dataclass(init=False) +class JSONSchemaOutput(Generic[OutputDataT]): + """Marker class to use JSON schema output for outputs.""" + + output_types: Sequence[OutputTypeOrFunction[OutputDataT]] + name: str | None + description: str | None + strict: bool | None + + def __init__( + self, + type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ): + self.output_types = flatten_output_types(type_) + self.name = name + self.description = description + self.strict = strict + + +class ManualJSONOutput(Generic[OutputDataT]): + """Marker class to use manual JSON mode for outputs.""" + + output_types: Sequence[OutputTypeOrFunction[OutputDataT]] + name: str | None + description: str | None + + def __init__( + self, + type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + ): + self.output_types = flatten_output_types(type_) + self.name = name + self.description = description + + T_co = TypeVar('T_co', covariant=True) -# output_type=Type or output_type=function or output_type=object.method -SimpleOutputType = TypeAliasType( - 'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,) -) -# output_type=ToolOutput() or -SimpleOutputTypeOrMarker = TypeAliasType( - 'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,) + +OutputTypeOrFunction = TypeAliasType( + 'OutputTypeOrFunction', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,) ) -# output_type= or [, ...] OutputType = TypeAliasType( - 'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,) + 'OutputType', + Union[ + OutputTypeOrFunction[T_co], + ToolOutput[T_co], + Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co]]], + JSONSchemaOutput[T_co], + ManualJSONOutput[T_co], + ], + type_params=(T_co,), ) +# TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose automatically, relying on Pydantic validation +type OutputMode = Literal['tool', 'json_schema', 'manual_json'] + @dataclass class OutputSchema(Generic[OutputDataT]): @@ -157,8 +218,10 @@ class OutputSchema(Generic[OutputDataT]): Similar to `Tool` but for the final output of running an agent. """ + forced_mode: OutputMode | None + object_schema: OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] tools: dict[str, OutputTool[OutputDataT]] - allow_text_output: bool + allow_text_output: Literal['plain', 'json'] | None = None @classmethod def build( @@ -172,65 +235,91 @@ def build( if output_type is str: return None - output_types: Sequence[SimpleOutputTypeOrMarker[OutputDataT]] - if isinstance(output_type, Sequence): - output_types = output_type - else: - output_types = (output_type,) - - output_types_flat: list[SimpleOutputTypeOrMarker[OutputDataT]] = [] - for output_type in output_types: - if union_types := get_union_args(output_type): - output_types_flat.extend(union_types) - else: - output_types_flat.append(output_type) - - allow_text_output = False - if str in output_types_flat: - allow_text_output = True - output_types_flat = [t for t in output_types_flat if t is not str] - - multiple = len(output_types_flat) > 1 - - default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME - default_tool_description = description - default_tool_strict = strict - + forced_mode: OutputMode | None = None + allow_text_output: Literal['plain', 'json'] | None = 'plain' tools: dict[str, OutputTool[OutputDataT]] = {} - for output_type in output_types_flat: - tool_name = None - tool_description = None - tool_strict = None - if isinstance(output_type, ToolOutput): - tool_output_type = output_type.output_type - # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads - tool_name = output_type.name - tool_description = output_type.description - tool_strict = output_type.strict - else: - tool_output_type = output_type - - if tool_name is None: - tool_name = default_tool_name - if multiple: - tool_name += f'_{tool_output_type.__name__}' - i = 1 - original_tool_name = tool_name - while tool_name in tools: - i += 1 - tool_name = f'{original_tool_name}_{i}' + output_types: Sequence[OutputTypeOrFunction[OutputDataT]] + if isinstance(output_type, JSONSchemaOutput): + forced_mode = 'json_schema' + output_types = output_type.output_types + name = output_type.name # TODO: If not set, use method arg? + description = output_type.description + strict = output_type.strict + allow_text_output = 'json' + elif isinstance(output_type, ManualJSONOutput): + forced_mode = 'manual_json' + output_types = output_type.output_types + name = output_type.name + description = output_type.description + allow_text_output = 'json' + else: + # TODO: We can't always force tool mode here, because some models may not support tools but will work with manual_json + output_types_or_tool_outputs = flatten_output_types(output_type) + + if str in output_types_or_tool_outputs: + forced_mode = 'tool' + allow_text_output = 'plain' + # TODO: What if str is the only item, e.g. `output_type=[str]` + output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str] + + multiple = len(output_types_or_tool_outputs) > 1 + + default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME + default_tool_description = description + default_tool_strict = strict + + output_types = [] + for output_type_or_tool_output in output_types_or_tool_outputs: + tool_name = None + tool_description = None + tool_strict = None + if isinstance(output_type_or_tool_output, ToolOutput): + forced_mode = 'tool' + tool_output = output_type_or_tool_output + output_type = tool_output.output_type + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads + tool_name = tool_output.name + tool_description = tool_output.description + tool_strict = tool_output.strict + else: + output_type = output_type_or_tool_output + + if tool_name is None: + tool_name = default_tool_name + if multiple: + tool_name += f'_{output_type.__name__}' + + i = 1 + original_tool_name = tool_name + while tool_name in tools: + i += 1 + tool_name = f'{original_tool_name}_{i}' + + tool_description = tool_description or default_tool_description + if tool_strict is None: + tool_strict = default_tool_strict + + parameters_schema = OutputObjectSchema( + output_type=output_type, description=tool_description, strict=tool_strict + ) + tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple) + output_types.append(output_type) - tool_description = tool_description or default_tool_description - if tool_strict is None: - tool_strict = default_tool_strict + output_types = flatten_output_types(output_types) - parameters_schema = OutputObjectSchema( - output_type=tool_output_type, description=tool_description, strict=tool_strict + if len(output_types) > 1: + output_object_schema = OutputUnionSchema( + output_types=output_types, name=name, description=description, strict=strict + ) + else: + output_object_schema = OutputObjectSchema( + output_type=output_types[0], name=name, description=description, strict=strict ) - tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple) return cls( + forced_mode=forced_mode, + object_schema=output_object_schema, tools=tools, allow_text_output=allow_text_output, ) @@ -262,9 +351,32 @@ def tool_defs(self) -> list[ToolDefinition]: """Get tool definitions to register with the model.""" return [t.tool_def for t in self.tools.values()] + async def process( + self, + data: str | dict[str, Any], + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + """Validate an output message. + + Args: + data: The output data to validate. + run_context: The current run context. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + return await self.object_schema.process( + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool: - return output_schema is None or output_schema.allow_text_output + # TODO: Add plain/json argument? + return output_schema is None or output_schema.allow_text_output is not None @dataclass @@ -274,27 +386,125 @@ class OutputObjectDefinition: description: str | None = None strict: bool | None = None + @property + def manual_json_instructions(self) -> str: + """Get instructions for model to output manual JSON matching the schema.""" + description = ': '.join([v for v in [self.name, self.description] if v]) + return DEFAULT_MANUAL_JSON_PROMPT.format(schema=json.dumps(self.json_schema), description=description) + + +@dataclass(init=False) +class OutputUnionDataEntry: + kind: str + data: dict[str, Any] + + +@dataclass(init=False) +class OutputUnionData: + result: OutputUnionDataEntry + + +# TODO: Better class naming +@dataclass(init=False) +class OutputUnionSchema(Generic[OutputDataT]): + definition: OutputObjectDefinition + outer_typed_dict_key: str = 'result' + _root_object_schema: OutputObjectSchema[OutputUnionData] + _object_schemas: dict[str, OutputObjectSchema[OutputDataT]] + + def __init__( + self, + output_types: Sequence[OutputTypeOrFunction[OutputDataT]], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ): + self._object_schemas = {} + # TODO: Ensure keys are unique + self._object_schemas = { + output_type.__name__: OutputObjectSchema(output_type=output_type) for output_type in output_types + } + + self._root_object_schema = OutputObjectSchema(output_type=OutputUnionData) + + # TODO: Account for conflicting $defs and $refs + json_schema = { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': { + 'const': name, + }, + 'data': object_schema.definition.json_schema, # TODO: Pop description here? + }, + 'description': object_schema.definition.description or name, # TODO: Better description + 'required': ['kind', 'data'], + 'additionalProperties': False, + } + for name, object_schema in self._object_schemas.items() + ], + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + + self.definition = OutputObjectDefinition( + name=name or DEFAULT_OUTPUT_TOOL_NAME, + description=description or DEFAULT_OUTPUT_TOOL_DESCRIPTION, + json_schema=json_schema, + strict=strict, + ) + + async def process( + self, + data: str | dict[str, Any] | None, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + # TODO: Error handling? + result = await self._root_object_schema.process( + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + result = result.result + kind = result.kind + data = result.data + try: + object_schema = self._object_schemas[kind] + except KeyError as e: + raise ToolRetryError(_messages.RetryPromptPart(content=f'Invalid kind: {kind}')) from e + + return await object_schema.process( + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + @dataclass(init=False) class OutputObjectSchema(Generic[OutputDataT]): definition: OutputObjectDefinition - validator: SchemaValidator - function_schema: _function_schema.FunctionSchema | None = None outer_typed_dict_key: str | None = None + _validator: SchemaValidator + _function_schema: _function_schema.FunctionSchema | None = None def __init__( self, *, - output_type: SimpleOutputType[OutputDataT], + output_type: OutputTypeOrFunction[OutputDataT], name: str | None = None, description: str | None = None, strict: bool | None = None, ): if inspect.isfunction(output_type) or inspect.ismethod(output_type): - self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) - self.validator = self.function_schema.validator - json_schema = self.function_schema.json_schema - json_schema['description'] = self.function_schema.description + self._function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) + self._validator = self._function_schema.validator + json_schema = self._function_schema.json_schema + json_schema['description'] = self._function_schema.description else: type_adapter: TypeAdapter[Any] if _utils.is_model_like(output_type): @@ -308,7 +518,7 @@ def __init__( type_adapter = TypeAdapter(response_data_typed_dict) # Really a PluggableSchemaValidator, but it's API-compatible - self.validator = cast(SchemaValidator, type_adapter.validator) + self._validator = cast(SchemaValidator, type_adapter.validator) json_schema = _utils.check_object_json_schema( type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) @@ -335,6 +545,7 @@ async def process( data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], allow_partial: bool = False, + wrap_validation_errors: bool = True, ) -> OutputDataT: """Process an output message, performing validation and (if necessary) calling the output function. @@ -342,18 +553,37 @@ async def process( data: The output data to validate. run_context: The current run context. allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. Returns: Either the validated output data (left) or a retry message (right). """ - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(data, str): - output = self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) - else: - output = self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) - - if self.function_schema: - output = await self.function_schema.call(output, run_context) + try: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) + else: + output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + except ValidationError as e: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=e.errors(include_url=False), + ) + raise ToolRetryError(m) from e + else: + raise + + if self._function_schema: + try: + output = await self._function_schema.call(output, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise if k := self.outer_typed_dict_key: output = output[k] @@ -402,7 +632,9 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - output = await self.parameters_schema.process(tool_call.args, run_context, allow_partial=allow_partial) + output = await self.parameters_schema.process( + tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False + ) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -437,3 +669,21 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return get_args(tp) else: return () + + +def flatten_output_types[T]( + output_type: T | Sequence[T], +) -> list[T]: + output_types: Sequence[T] + if isinstance(output_type, Sequence): + output_types = output_type + else: + output_types = (output_type,) + + output_types_flat: list[T] = [] + for output_type in output_types: + if union_types := get_union_args(output_type): + output_types_flat.extend(union_types) + else: + output_types_flat.append(output_type) + return output_types_flat diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 57010f5f8..abcb32818 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -326,6 +326,7 @@ def __init__( self._instructions_functions = [] if isinstance(instructions, (str, Callable)): instructions = [instructions] + # TODO: Add OutputSchema to the instructions in JSON mode for instruction in instructions or []: if isinstance(instruction, str): self._instructions += instruction + '\n' diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 714080305..1961938f2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -18,6 +18,7 @@ from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec +from .._output import OutputMode, OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent @@ -292,8 +293,13 @@ class ModelRequestParameters: """Configuration for an agent's request to a model, specifically related to tools and output handling.""" function_tools: list[ToolDefinition] = field(default_factory=list) - allow_text_output: bool = True + + output_mode: OutputMode | None = None + output_object: OutputObjectDefinition | None = None output_tools: list[ToolDefinition] = field(default_factory=list) + require_tool_use: bool = ( + True # TODO: Rename back to allow_text_output because this is public API. Support bool as well as plain/json + ) class Model(ABC): @@ -338,6 +344,11 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools], output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools], ) + if output_object := model_request_parameters.output_object: + model_request_parameters = replace( + model_request_parameters, + output_object=_customize_output_object(transformer, output_object), + ) return model_request_parameters @@ -419,6 +430,18 @@ def _get_instructions(messages: list[ModelMessage]) -> str | None: return None + @property + def supported_output_modes(self) -> set[OutputMode]: + """The supported output modes for the model.""" + # TODO: Move to ModelProfile + return {'tool'} # TODO: Support manual_json on all + + @property + def default_output_mode(self) -> OutputMode: + """The default output mode for the model.""" + # TODO: Move to ModelProfile + return 'tool' + @dataclass class StreamedResponse(ABC): @@ -620,3 +643,11 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinit if t.strict is None: t = replace(t, strict=schema_transformer.is_strict_compatible) return replace(t, parameters_json_schema=parameters_json_schema) + + +def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition): + schema_transformer = transformer(o.json_schema, strict=o.strict) + son_schema = schema_transformer.walk() + if o.strict is None: + o = replace(o, strict=schema_transformer.is_strict_compatible) + return replace(o, json_schema=son_schema) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 69808aa5f..e75cb4a18 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -213,7 +213,7 @@ async def _messages_create( if not tools: tool_choice = None else: - if not model_request_parameters.allow_text_output: + if model_request_parameters.require_tool_use: tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index fb8753b43..6e4e4797e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -378,7 +378,7 @@ def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> return None tool_choice: ToolChoiceTypeDef - if not model_request_parameters.allow_text_output: + if model_request_parameters.require_tool_use: tool_choice = {'any': {}} else: tool_choice = {'auto': {}} diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 22bcddffb..934707cce 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -91,7 +91,7 @@ async def request( ) -> ModelResponse: agent_info = AgentInfo( model_request_parameters.function_tools, - model_request_parameters.allow_text_output, + not model_request_parameters.require_tool_use, model_request_parameters.output_tools, model_settings, ) @@ -120,7 +120,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: agent_info = AgentInfo( model_request_parameters.function_tools, - model_request_parameters.allow_text_output, + not model_request_parameters.require_tool_use, model_request_parameters.output_tools, model_settings, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index a93c80a97..aa8c23eed 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -192,7 +192,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: - if model_request_parameters.allow_text_output: + if not model_request_parameters.require_tool_use: return None elif tools: return _tool_config([t['name'] for t in tools['function_declarations']]) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 8af4261ea..38799712e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -213,7 +213,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: - if model_request_parameters.allow_text_output: + if not model_request_parameters.require_tool_use: return None elif tools: names: list[str] = [] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 067cfb516..c6dfed3bd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -209,7 +209,7 @@ async def _completions_create( # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: tool_choice = 'required' else: tool_choice = 'auto' diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index a0838edbc..0ff134a91 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -250,6 +250,7 @@ async def _stream_completions_create( ) elif model_request_parameters.output_tools: + # TODO: Port to native "manual JSON" mode # Json Mode parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools] user_output_format_message = self._generate_user_output_format(parameters_json_schemas) @@ -258,7 +259,9 @@ async def _stream_completions_create( response = await self.client.chat.stream_async( model=str(self._model_name), messages=mistral_messages, - response_format={'type': 'json_object'}, + response_format={ + 'type': 'json_object' + }, # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9 stream=True, http_headers={'User-Agent': get_user_agent()}, ) @@ -284,7 +287,7 @@ def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> """ if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None - elif not model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: return 'required' else: return 'auto' @@ -566,6 +569,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Attempt to produce an output tool call from the received text if self._output_tools: self._delta_content += text + # TODO: Port to native "manual JSON" mode maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools) if maybe_tool_call_part: yield self._parts_manager.handle_tool_call_part( diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 4e99fb574..1f0675123 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -14,6 +14,7 @@ from pydantic_ai.providers import Provider, infer_provider from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._output import OutputMode, OutputObjectDefinition from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( AudioUrl, @@ -237,6 +238,11 @@ def system(self) -> str: """The system / model provider.""" return self._system + @property + def supported_output_modes(self) -> set[OutputMode]: + """The supported output modes for the model.""" + return {'tool', 'json_schema', 'manual_json'} + @overload async def _completions_create( self, @@ -262,18 +268,31 @@ async def _completions_create( model_settings: OpenAIModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: - tools = self._get_tools(model_request_parameters) + openai_messages = await self._map_messages(messages) + + tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] + response_format: chat.completion_create_params.ResponseFormat | None = None + + if model_request_parameters.output_mode == 'tool': + tools.extend(self._map_tool_definition(r) for r in model_request_parameters.output_tools) + elif output_object := model_request_parameters.output_object: + if model_request_parameters.output_mode == 'json_schema': + response_format = self._map_output_object_definition(output_object) + elif model_request_parameters.output_mode == 'manual_json': + openai_messages.insert( + 0, + chat.ChatCompletionSystemMessageParam( + content=output_object.manual_json_instructions, role='system' + ), + ) - # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: tool_choice = 'required' else: tool_choice = 'auto' - openai_messages = await self._map_messages(messages) - try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) @@ -290,6 +309,7 @@ async def _completions_create( temperature=model_settings.get('temperature', NOT_GIVEN), top_p=model_settings.get('top_p', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), + response_format=response_format or NOT_GIVEN, seed=model_settings.get('seed', NOT_GIVEN), presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN), frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN), @@ -404,6 +424,22 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: function={'name': t.tool_name, 'arguments': t.args_as_json_str()}, ) + @staticmethod + def _map_output_object_definition(o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: + # TODO: Use ResponseFormatJSONObject on older models + response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage] + 'type': 'json_schema', + 'json_schema': { + 'name': o.name, + 'schema': o.json_schema, + }, + } + if o.description: + response_format_param['json_schema']['description'] = o.description + if o.strict: + response_format_param['json_schema']['strict'] = o.strict + return response_format_param + def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolParam: tool_param: chat.ChatCompletionToolParam = { 'type': 'function', @@ -563,6 +599,11 @@ def system(self) -> str: """The system / model provider.""" return self._system + @property + def supported_output_modes(self) -> set[OutputMode]: + """The supported output modes for the model.""" + return {'tool', 'json_schema'} + async def request( self, messages: list[ModelRequest | ModelResponse], @@ -646,7 +687,7 @@ async def _responses_create( # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif not model_request_parameters.allow_text_output: + elif model_request_parameters.require_tool_use: tool_choice = 'required' else: tool_choice = 'auto' @@ -657,6 +698,8 @@ async def _responses_create( try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) + # TODO: Pass text.format = ResponseFormatTextJSONSchemaConfigParam(...): {'type': 'json_schema', 'strict': True, 'name': '...', 'schema': ...} + # TODO: Fall back on ResponseFormatJSONObject/json_object on older models? return await self.client.responses.create( input=openai_messages, model=self._model_name, @@ -853,6 +896,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: + # TODO: Handle structured output yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) for dtc in choice.delta.tool_calls or []: @@ -934,6 +978,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass elif isinstance(chunk, responses.ResponseTextDeltaEvent): + # TODO: Handle structured output yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta) elif isinstance(chunk, responses.ResponseTextDoneEvent): diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 0daad25bc..9079afa5e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -130,7 +130,7 @@ def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> l def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput: if self.custom_output_text is not None: - assert model_request_parameters.allow_text_output, ( + assert not model_request_parameters.require_tool_use, ( 'Plain response not allowed, but `custom_output_text` is set.' ) assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.' @@ -145,7 +145,7 @@ def _get_output(self, model_request_parameters: ModelRequestParameters) -> _Wrap return _WrappedToolOutput({k: self.custom_output_args}) else: return _WrappedToolOutput(self.custom_output_args) - elif model_request_parameters.allow_text_output: + elif not model_request_parameters.require_tool_use: return _WrappedTextOutput(None) elif model_request_parameters.output_tools: return _WrappedToolOutput(None) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 443e98b32..6c1007266 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -11,6 +11,8 @@ from . import _output, _utils, exceptions, messages as _messages, models from ._output import ( + JSONSchemaOutput, + ManualJSONOutput, OutputDataT, OutputDataT_inv, OutputSchema, @@ -22,7 +24,7 @@ from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'OutputValidatorFunc' +__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'JSONSchemaOutput', 'ManualJSONOutput', 'OutputValidatorFunc' T = TypeVar('T') @@ -93,8 +95,14 @@ async def _validate_response( ) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - # The following cast is safe because we know `str` is an allowed output type - result_data = cast(OutputDataT, text) + + if self._output_schema is None or self._output_schema.allow_text_output == 'plain': + # The following cast is safe because we know `str` is an allowed output type + result_data = cast(OutputDataT, text) + else: + result_data = await self._output_schema.process( + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) @@ -311,7 +319,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ - if self._output_schema and not self._output_schema.allow_text_output: + if self._output_schema and self._output_schema.allow_text_output != 'plain': raise exceptions.UserError('stream_text() can only be used with text responses') if delta: @@ -403,7 +411,14 @@ async def validate_structured_output( ) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - result_data = cast(OutputDataT, text) + + if self._output_schema is None or self._output_schema.allow_text_output == 'plain': + # The following cast is safe because we know `str` is an allowed output type + result_data = cast(OutputDataT, text) + else: + result_data = await self._output_schema.process( + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml new file mode 100644 index 000000000..ff4477f3d --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml @@ -0,0 +1,223 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '522' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '341' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + created: 1746142582 + id: chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 71 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 83 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '753' + content-type: + - application/json + cookie: + - __cf_bm=dOa3_E1SoWV.vbgZ8L7tx8o9S.XyNTE.YS0K0I3JHq4-1746142583-1.0.1.1-0TuvhdYsoD.J1522DBXH0yrAP_M9MlzvlcpyfwQQNZy.KO5gri6ejQ.gFuwLV5hGhuY0W2uI1dN7ZF1lirVHKeEnEz5s_89aJjrMWjyBd8M; + _cfuvid=xQIJVHkOP28w5fPnAvDHPiCRlU7kkNj6iFV87W4u8Ds-1746142583128-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + - content: Mexico + role: tool + tool_call_id: call_PkRGedQNRFUzJp2R7dO7avWR + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '852' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '553' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1746142583 + id: chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 15 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 92 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 107 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml new file mode 100644 index 000000000..bda52f9bf --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml @@ -0,0 +1,293 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1120' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + response_format: + json_schema: + description: The final response which ends this conversation + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + required: + - kind + - data + type: object + required: + - result + type: object + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '999' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_NiLmkD3Yi30ax2IY7t14e3AP + type: function + created: 1748919916 + id: chatcmpl-BeCFgrwfENi1OwavvP8itSOMTKjwY + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 160 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 171 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1351' + content-type: + - application/json + cookie: + - __cf_bm=4QxpNnP8_u4FyljGgYUUF5NYWCyqa2OJvgxKjyMEh4Y-1748919917-1.0.1.1-LsEozdwJs4K6NOAxtY3kw9dzZ6JHe.l4h4qkIENfShBXiUE6C5V9ED_hCbYeM.GMdC13g7SAlw1iuh5HCTMtOzNvTr_j_jvPbLY3p35HCbM; + _cfuvid=.3C6J8WR_NWUd_EBaQOj9bFncgO.R9A8576Zi3GczTg-1748919917308-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_NiLmkD3Yi30ax2IY7t14e3AP + type: function + - content: Mexico + role: tool + tool_call_id: call_NiLmkD3Yi30ax2IY7t14e3AP + model: gpt-4o + response_format: + json_schema: + description: The final response which ends this conversation + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + required: + - kind + - data + type: object + required: + - result + type: object + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '903' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '867' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + refusal: null + role: assistant + created: 1748919918 + id: chatcmpl-BeCFiQtbjmFUzYbmXlkAEWbc0peoL + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 25 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 181 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 206 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml b/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml new file mode 100644 index 000000000..56023f426 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml @@ -0,0 +1,211 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '627' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |2 + + Always respond with a JSON object matching this description and schema: + + CityLocation + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + n: 1 + stream: false + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '430' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_uTjt2vMkeTr0GYqQyQYrUUhl + type: function + created: 1747154400 + id: chatcmpl-BWmxcJf2wXM37xTze50IfAoyuaoKb + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_55d88aaf2f + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 106 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 118 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '858' + content-type: + - application/json + cookie: + - __cf_bm=95NT6qevASASUyV3RVHQoxZGp8lnU1dQzcdShJ0rQ8o-1747154400-1.0.1.1-zowTt2i3mTZlYQ8gezUuRRLY_0dw6L6iD5qfaNySs0KmHmLd2JFwYun1kZJ61S03BecMhUdxy.FiOWLq2LdY.RuTR7vePLyoCrMmCDa4vpk; + _cfuvid=hgD2spnngVs.0HuyvQx7_W1uCro2gMmGvsKkZTUk3H0-1747154400314-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |2 + + Always respond with a JSON object matching this description and schema: + + CityLocation + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_uTjt2vMkeTr0GYqQyQYrUUhl + type: function + - content: Mexico + role: tool + tool_call_id: call_uTjt2vMkeTr0GYqQyQYrUUhl + model: gpt-4o + n: 1 + stream: false + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '853' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '2453' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1747154401 + id: chatcmpl-BWmxdIs5pO5RCbQ9qRxxtWVItB4NU + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_d8864f8b6b + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 127 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 139 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_tool_output.yaml b/tests/models/cassettes/test_openai/test_openai_tool_output.yaml new file mode 100644 index 000000000..56f7441f1 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_tool_output.yaml @@ -0,0 +1,227 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '561' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + n: 1 + stream: false + tool_choice: required + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '348' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_iXFttys57ap0o16JSlC8yhYo + type: function + created: 1746142584 + id: chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 68 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 80 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '792' + content-type: + - application/json + cookie: + - __cf_bm=yM.C6I_kAzJk3Dm7H52actN1zAEW8fj.Gd2yeJ7tKN0-1746142584-1.0.1.1-xk91aElDtLLC8aROrOKHlp5vck_h.R.zQkS6OrsiBOwuFA8rE1kGswpactMEtYxV9WgWDN2B4S2B4zs8heyxmcfiNjmOf075n.OPqYpVla4; + _cfuvid=JCllInpf6fg1JdOS7xSj3bZOXYf9PYJ8uoamRTx7ku4-1746142584855-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_iXFttys57ap0o16JSlC8yhYo + type: function + - content: Mexico + role: tool + tool_call_id: call_iXFttys57ap0o16JSlC8yhYo + model: gpt-4o + n: 1 + stream: false + tool_choice: required + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1113' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1919' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{"city": "Mexico City", "country": "Mexico"}' + name: final_result + id: call_gmD2oUZUzSoCkmNmp3JPUF7R + type: function + created: 1746142585 + id: chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 36 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 89 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 125 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 0a0dec8f8..35578309c 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -684,7 +684,7 @@ async def test_bedrock_anthropic_no_tool_choice(bedrock_provider: BedrockProvide 'This is my tool', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}, ) - mrp = ModelRequestParameters(function_tools=[my_tool], allow_text_output=False, output_tools=[]) + mrp = ModelRequestParameters(function_tools=[my_tool], require_tool_use=True, output_tools=[]) # Models other than Anthropic support tool_choice model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index db6277527..ad6c1bab1 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -127,7 +127,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:success_response:', 'gen_ai.system': 'function', @@ -200,7 +200,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function::failure_response_stream,function::success_response_stream', 'gen_ai.system': 'function', @@ -272,7 +272,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'fallback:function,function', 'gen_ai.request.model': 'fallback:function:failure_response:,function:failure_response:', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:failure_response:', diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 13d831e4b..ca26dd408 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -63,7 +63,9 @@ async def test_model_simple(allow_model_requests: None): assert m.model_name == 'gemini-1.5-flash' assert 'x-goog-api-key' in m.client.headers - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[]) + mrp = ModelRequestParameters( + function_tools=[], require_tool_use=False, output_tools=[], output_mode=None, output_object=None + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -96,7 +98,13 @@ async def test_model_tools(allow_model_requests: None): {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']}, ) - mrp = ModelRequestParameters(function_tools=tools, allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=tools, + require_tool_use=False, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -138,7 +146,13 @@ async def test_require_response_tool(allow_model_requests: None): 'This is the tool for the final Result', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=False, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + require_tool_use=True, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -219,7 +233,13 @@ class Locations(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + require_tool_use=False, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( { @@ -298,7 +318,13 @@ class QueryDetails(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + output_mode=None, + require_tool_use=False, + output_tools=[output_tool], + output_object=None, + ) mrp = m.customize_request_parameters(mrp) # This tests that the enum values are properly converted to strings for Gemini @@ -340,7 +366,13 @@ class Locations(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + require_tool_use=False, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( _GeminiTools( @@ -404,7 +436,13 @@ class Location(BaseModel): json_schema, ) with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'): - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + require_tool_use=False, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) @@ -436,7 +474,13 @@ class FormattedStringFields(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + require_tool_use=False, + output_tools=[output_tool], + output_mode=None, + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( _GeminiTools( diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index f7caad399..5bb8204db 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -132,8 +132,10 @@ async def test_instrumented_model(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], + output_mode=None, + output_object=None, ), ) @@ -151,7 +153,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -328,8 +330,10 @@ async def test_instrumented_model_not_recording(): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], + output_mode=None, + output_object=None, ), ) @@ -350,8 +354,10 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], + output_mode=None, + output_object=None, ), ) as response_stream: assert [event async for event in response_stream] == snapshot( @@ -375,7 +381,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -438,8 +444,10 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], + output_mode=None, + output_object=None, ), ) as response_stream: async for event in response_stream: # pragma: no branch @@ -460,7 +468,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -541,8 +549,10 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - allow_text_output=True, + require_tool_use=False, output_tools=[], + output_mode=None, + output_object=None, ), ) @@ -560,7 +570,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', 'logfire.span_type': 'span', diff --git a/tests/models/test_model_request_parameters.py b/tests/models/test_model_request_parameters.py index 03910db11..5a0918211 100644 --- a/tests/models/test_model_request_parameters.py +++ b/tests/models/test_model_request_parameters.py @@ -4,9 +4,13 @@ def test_model_request_parameters_are_serializable(): - params = ModelRequestParameters(function_tools=[], allow_text_output=False, output_tools=[]) + params = ModelRequestParameters( + function_tools=[], output_mode=None, require_tool_use=False, output_tools=[], output_object=None + ) assert TypeAdapter(ModelRequestParameters).dump_python(params) == { 'function_tools': [], - 'allow_text_output': False, + 'preferred_output_mode': None, + 'require_tool_use': False, 'output_tools': [], + 'output_object': None, } diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index b0c8cbf62..d8702e384 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -15,6 +15,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior +from pydantic_ai._output import ManualJSONOutput from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -34,7 +35,7 @@ from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import JSONSchemaOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -1710,3 +1711,352 @@ def test_model_profile_strict_not_supported(): }, } ) + + +@pytest.mark.vcr() +async def test_openai_tool_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=68, + response_tokens=12, + total_tokens=80, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City", "country": "Mexico"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage( + requests=1, + request_tokens=89, + response_tokens=36, + total_tokens=125, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_json_schema_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JSONSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR') + ], + usage=Usage( + requests=1, + request_tokens=71, + response_tokens=12, + total_tokens=83, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + requests=1, + request_tokens=92, + response_tokens=15, + total_tokens=107, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_json_schema_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + # TODO: Test with functions! + agent = Agent(m, output_type=JSONSchemaOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + # TODO: Show what response_format looks like + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_NiLmkD3Yi30ax2IY7t14e3AP') + ], + usage=Usage( + requests=1, + request_tokens=160, + response_tokens=11, + total_tokens=171, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BeCFgrwfENi1OwavvP8itSOMTKjwY', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_NiLmkD3Yi30ax2IY7t14e3AP', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=181, + response_tokens=25, + total_tokens=206, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BeCFiQtbjmFUzYbmXlkAEWbc0peoL', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_manual_json_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ManualJSONOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_uTjt2vMkeTr0GYqQyQYrUUhl') + ], + usage=Usage( + requests=1, + request_tokens=106, + response_tokens=12, + total_tokens=118, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BWmxcJf2wXM37xTze50IfAoyuaoKb', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_uTjt2vMkeTr0GYqQyQYrUUhl', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + requests=1, + request_tokens=127, + response_tokens=12, + total_tokens=139, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BWmxdIs5pO5RCbQ9qRxxtWVItB4NU', + ), + ] + ) diff --git a/tests/test_agent.py b/tests/test_agent.py index d59848155..6be9f8e30 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -13,6 +13,7 @@ from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai._output import ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -30,7 +31,7 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import ToolOutput, Usage +from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -261,7 +262,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: args_json = '{"response": ["foo", "bar"]}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) - agent = Agent(FunctionModel(return_tuple), output_type=tuple[str, str]) + agent = Agent(FunctionModel(return_tuple), output_type=ToolOutput(tuple[str, str])) result = agent.run_sync('Hello') assert result.output == ('foo', 'bar') @@ -353,14 +354,14 @@ def test_response_tuple(): m = TestModel() agent = Agent(m, output_type=tuple[str, str]) - assert agent._output_schema.allow_text_output is False # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert agent._output_schema.allow_text_output is None # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] result = agent.run_sync('Hello') assert result.output == snapshot(('a', 'a')) assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.allow_text_output is False + assert m.last_model_request_parameters.require_tool_use is True assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 1 @@ -410,7 +411,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: got_tool_call_name = ctx.tool_name return o - assert agent._output_schema.allow_text_output is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert agent._output_schema.allow_text_output == 'plain' # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] result = agent.run_sync('Hello') assert result.output == snapshot('success (no tool calls)') @@ -418,7 +419,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.allow_text_output is True + assert m.last_model_request_parameters.require_tool_use is False assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 1 @@ -496,7 +497,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.allow_text_output is False + assert m.last_model_request_parameters.require_tool_use is True assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 2 diff --git a/tests/test_direct.py b/tests/test_direct.py index 46f409ef0..16c2b0bd2 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -51,7 +51,7 @@ async def test_model_request_tool_call(): function_tools=[ ToolDefinition(name='tool_name', description='', parameters_json_schema={'type': 'object'}) ], - allow_text_output=False, + require_tool_use=True, ), ) assert model_response == snapshot( diff --git a/tests/test_logfire.py b/tests/test_logfire.py index e63f358f3..eff2c62ba 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -223,8 +223,10 @@ async def my_ret(x: int) -> str: 'strict': None, } ], - 'allow_text_output': True, + 'output_mode': None, 'output_tools': [], + 'output_object': None, + 'require_tool_use': False, } ) ), @@ -404,7 +406,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'test', 'gen_ai.request.model': 'test', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', 'logfire.span_type': 'span', 'logfire.msg': 'chat test', 'gen_ai.usage.input_tokens': 51, From 2056539a1a75d1e76b5786e51d1c6fe0487fe2f1 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 3 Jun 2025 05:05:26 +0000 Subject: [PATCH 02/90] WIP: More output modes --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 66 ++++++--------- pydantic_ai_slim/pydantic_ai/_output.py | 80 ++++++++++++------- pydantic_ai_slim/pydantic_ai/agent.py | 51 ++++++++---- .../pydantic_ai/models/__init__.py | 18 +---- .../pydantic_ai/models/anthropic.py | 2 +- .../pydantic_ai/models/bedrock.py | 2 +- .../pydantic_ai/models/function.py | 4 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 2 +- .../pydantic_ai/models/mistral.py | 2 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 35 +++----- pydantic_ai_slim/pydantic_ai/models/test.py | 4 +- .../pydantic_ai/profiles/__init__.py | 7 +- .../pydantic_ai/profiles/openai.py | 2 +- pydantic_ai_slim/pydantic_ai/result.py | 43 ++++------ tests/models/test_bedrock.py | 2 +- tests/models/test_fallback.py | 6 +- tests/models/test_gemini.py | 30 +++---- tests/models/test_instrumented.py | 28 +++---- tests/models/test_model_request_parameters.py | 6 +- tests/models/test_openai.py | 22 ++++- tests/test_agent.py | 10 +-- tests/test_direct.py | 2 +- tests/test_logfire.py | 6 +- tests/test_tools.py | 2 +- 26 files changed, 216 insertions(+), 220 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index f6ec91268..e6fdc7224 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union from opentelemetry.trace import Tracer from typing_extensions import TypeGuard, TypeVar, assert_never @@ -90,7 +90,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): end_strategy: EndStrategy get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]] - output_schema: _output.OutputSchema[OutputDataT] | None + output_schema: _output.OutputSchema[OutputDataT] output_validators: list[_output.OutputValidator[DepsT, OutputDataT]] function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) @@ -264,29 +264,14 @@ async def add_mcp_server_tools(server: MCPServer) -> None: function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] output_schema = ctx.deps.output_schema - model = ctx.deps.model - - # TODO: This is horrible - output_mode = None - output_object = None - output_tools = [] - require_tool_use = False - if output_schema: - output_mode = output_schema.forced_mode or model.default_output_mode - output_object = output_schema.object_schema.definition - output_tools = output_schema.tool_defs() - require_tool_use = output_mode == 'tool' and output_schema.allow_text_output != 'plain' - - supported_modes = model.supported_output_modes - if output_mode not in supported_modes: - raise exceptions.UserError(f"Output mode '{output_mode}' is not among supported modes: {supported_modes}") + assert output_schema.mode is not None # Should have been set in agent._prepare_output_schema return models.ModelRequestParameters( function_tools=function_tool_defs, - output_mode=output_mode, - output_object=output_object, - output_tools=output_tools, - require_tool_use=require_tool_use, + output_mode=output_schema.mode, + output_object=output_schema.object_schema.definition if output_schema.object_schema else None, + output_tools=output_schema.tool_defs(), + allow_text_output=output_schema.allow_text_output == 'plain', ) @@ -471,7 +456,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # when the model has already returned text along side tool calls # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any - if _output.allow_text_output(ctx.deps.output_schema): + if ctx.deps.output_schema.allow_text_output: for message in reversed(ctx.state.message_history): if isinstance(message, _messages.ModelResponse): last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)] @@ -497,19 +482,18 @@ async def _handle_tool_calls( # first, look for the output tool call final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] - if output_schema is not None: - for call, output_tool in output_schema.find_tool(tool_calls): - try: - result_data = await output_tool.process(call, run_context) - result_data = await _validate_output(result_data, ctx, call) - except _output.ToolRetryError as e: - # TODO: Should only increment retry stuff once per node execution, not for each tool call - # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries, e) - parts.append(e.tool_retry) - else: - final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - break + for call, output_tool in output_schema.find_tool(tool_calls): + try: + result_data = await output_tool.process(call, run_context) + result_data = await _validate_output(result_data, ctx, call) + except _output.ToolRetryError as e: + # TODO: Should only increment retry stuff once per node execution, not for each tool call + # Also, should increment the tool-specific retry count rather than the run retry count + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + parts.append(e.tool_retry) + else: + final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) + break # Then build the other request parts based on end strategy tool_responses: list[_messages.ModelRequestPart] = self._tool_responses @@ -555,10 +539,7 @@ async def _handle_text_response( text = '\n\n'.join(texts) try: - if output_schema is None or output_schema.allow_text_output == 'plain': - # The following cast is safe because we know `str` is an allowed result type - result_data = cast(NodeRunEndT, text) - elif output_schema.allow_text_output == 'json': + if output_schema.allow_text_output: run_context = build_run_context(ctx) result_data = await output_schema.process(text, run_context) else: @@ -659,7 +640,7 @@ async def process_function_tools( # noqa C901 yield event call_index_to_event_id[len(calls_to_run)] = event.call_id calls_to_run.append((mcp_tool, call)) - elif output_schema is not None and call.tool_name in output_schema.tools: + elif call.tool_name in output_schema.tools: # if tool_name is in output_schema, it means we found a output tool but an error occurred in # validation, we don't add another part here if output_tool_name is not None: @@ -788,8 +769,7 @@ def _unknown_tool( ) -> _messages.RetryPromptPart: ctx.state.increment_retries(ctx.deps.max_result_retries) tool_names = list(ctx.deps.function_tools.keys()) - if output_schema := ctx.deps.output_schema: - tool_names.extend(output_schema.tool_names()) + tool_names.extend(ctx.deps.output_schema.tool_names()) if tool_names: msg = f'Available tools: {", ".join(tool_names)}' diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index e9a5030bd..83a623c93 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -13,6 +13,8 @@ from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin +from pydantic_ai.profiles import ModelProfile + from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition @@ -208,7 +210,7 @@ def __init__( ) # TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose automatically, relying on Pydantic validation -type OutputMode = Literal['tool', 'json_schema', 'manual_json'] +type OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'manual_json'] @dataclass @@ -218,50 +220,46 @@ class OutputSchema(Generic[OutputDataT]): Similar to `Tool` but for the final output of running an agent. """ - forced_mode: OutputMode | None - object_schema: OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] - tools: dict[str, OutputTool[OutputDataT]] - allow_text_output: Literal['plain', 'json'] | None = None + mode: OutputMode | None + object_schema: OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | None = None + tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) @classmethod def build( cls: type[OutputSchema[OutputDataT]], output_type: OutputType[OutputDataT], - name: str | None = None, - description: str | None = None, - strict: bool | None = None, - ) -> OutputSchema[OutputDataT] | None: + name: str | None, + description: str | None, + ) -> OutputSchema[OutputDataT]: """Build an OutputSchema dataclass from an output type.""" if output_type is str: - return None + return cls(mode='text') - forced_mode: OutputMode | None = None - allow_text_output: Literal['plain', 'json'] | None = 'plain' + mode: OutputMode | None = None tools: dict[str, OutputTool[OutputDataT]] = {} + strict: bool | None = None output_types: Sequence[OutputTypeOrFunction[OutputDataT]] if isinstance(output_type, JSONSchemaOutput): - forced_mode = 'json_schema' + mode = 'json_schema' output_types = output_type.output_types name = output_type.name # TODO: If not set, use method arg? description = output_type.description strict = output_type.strict - allow_text_output = 'json' elif isinstance(output_type, ManualJSONOutput): - forced_mode = 'manual_json' + mode = 'manual_json' output_types = output_type.output_types name = output_type.name description = output_type.description - allow_text_output = 'json' else: - # TODO: We can't always force tool mode here, because some models may not support tools but will work with manual_json output_types_or_tool_outputs = flatten_output_types(output_type) if str in output_types_or_tool_outputs: - forced_mode = 'tool' - allow_text_output = 'plain' - # TODO: What if str is the only item, e.g. `output_type=[str]` - output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str] + if len(output_types_or_tool_outputs) == 1: + return cls(mode='text') + else: + mode = 'tool_or_text' + output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str] multiple = len(output_types_or_tool_outputs) > 1 @@ -275,7 +273,9 @@ def build( tool_description = None tool_strict = None if isinstance(output_type_or_tool_output, ToolOutput): - forced_mode = 'tool' + if mode is None: + mode = 'tool' + tool_output = output_type_or_tool_output output_type = tool_output.output_type # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads @@ -307,7 +307,6 @@ def build( output_types.append(output_type) output_types = flatten_output_types(output_types) - if len(output_types) > 1: output_object_schema = OutputUnionSchema( output_types=output_types, name=name, description=description, strict=strict @@ -318,12 +317,30 @@ def build( ) return cls( - forced_mode=forced_mode, + mode=mode, object_schema=output_object_schema, tools=tools, - allow_text_output=allow_text_output, ) + @property + def allow_text_output(self) -> Literal['plain', 'json', False]: + """Whether the model allows text output.""" + if self.mode in ('text', 'tool_or_text'): + return 'plain' + elif self.mode in ('json_schema', 'manual_json'): + return 'json' + else: # tool-only mode + return False + + def is_mode_supported(self, profile: ModelProfile) -> bool: + """Whether the model supports the output mode.""" + mode = self.mode + if mode in ('text', 'manual_json'): + return True + if self.mode == 'tool_or_text': + mode = 'tool' + return mode in profile.output_modes + def find_named_tool( self, parts: Iterable[_messages.ModelResponsePart], tool_name: str ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: @@ -369,16 +386,18 @@ async def process( Returns: Either the validated output data (left) or a retry message (right). """ + assert self.allow_text_output is not False + + if self.allow_text_output == 'plain': + return cast(OutputDataT, data) + + assert self.object_schema is not None + return await self.object_schema.process( data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) -def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool: - # TODO: Add plain/json argument? - return output_schema is None or output_schema.allow_text_output is not None - - @dataclass class OutputObjectDefinition: name: str @@ -389,6 +408,7 @@ class OutputObjectDefinition: @property def manual_json_instructions(self) -> str: """Get instructions for model to output manual JSON matching the schema.""" + # TODO: Move to ModelProfile so it can be tweaked description = ': '.join([v for v in [self.name, self.description] if v]) return DEFAULT_MANUAL_JSON_PROMPT.format(schema=json.dumps(self.json_schema), description=description) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index abcb32818..25f1d80f7 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -14,6 +14,7 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated +from pydantic_ai.profiles import ModelProfile from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -140,7 +141,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) - _output_schema: _output.OutputSchema[OutputDataT] | None = dataclasses.field(repr=False) + _output_schema: _output.OutputSchema[OutputDataT] = dataclasses.field(repr=False) _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) _instructions: str | None = dataclasses.field(repr=False) _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) @@ -318,7 +319,9 @@ def __init__( output_retries = result_retries self._output_schema = _output.OutputSchema[OutputDataT].build( - output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description + output_type, + self._deprecated_result_tool_name, + self._deprecated_result_tool_description, ) self._output_validators = [] @@ -624,7 +627,7 @@ async def main(): deps = self._get_deps(deps) new_message_index = len(message_history) if message_history else 0 - output_schema = self._prepare_output_schema(output_type) + output_schema = self._prepare_output_schema(output_type, model_used.profile) output_type_ = output_type or self.output_type @@ -672,13 +675,18 @@ async def main(): ) async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: - if self._instructions is None and not self._instructions_functions: - return None + parts = [ + self._instructions, + *[await func.run(run_context) for func in self._instructions_functions], + ] - instructions = self._instructions or '' - for instructions_runner in self._instructions_functions: - instructions += '\n' + await instructions_runner.run(run_context) - return instructions.strip() + if output_schema.mode == 'manual_json' and (output_object_schema := output_schema.object_schema): + parts.append(output_object_schema.definition.manual_json_instructions) + + parts = [p for p in parts if p] + if not parts: + return None + return '\n\n'.join(parts).strip() graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( user_deps=deps, @@ -995,9 +1003,9 @@ async def stream_to_final( if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, _messages.TextPart): - if _output.allow_text_output(output_schema): + if output_schema.allow_text_output: return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart) and output_schema: + elif isinstance(new_part, _messages.ToolCallPart): for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) return None @@ -1553,8 +1561,8 @@ def _register_tool(self, tool: Tool[AgentDepsT]) -> None: if tool.name in self._function_tools: raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') - if self._output_schema and tool.name in self._output_schema.tools: - raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}') + if tool.name in self._output_schema.tools: + raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}') self._function_tools[tool.name] = tool @@ -1629,18 +1637,27 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( - self, output_type: _output.OutputType[RunOutputDataT] | None - ) -> _output.OutputSchema[RunOutputDataT] | None: + self, output_type: _output.OutputType[RunOutputDataT] | None, model_profile: ModelProfile + ) -> _output.OutputSchema[RunOutputDataT]: if output_type is not None: if self._output_validators: raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') - return _output.OutputSchema[RunOutputDataT].build( + schema = _output.OutputSchema[RunOutputDataT].build( output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description, ) else: - return self._output_schema # pyright: ignore[reportReturnType] + schema = self._output_schema + + if schema.mode is None: + schema.mode = model_profile.default_output_mode + if not schema.is_mode_supported(model_profile): + raise exceptions.UserError( + f"Output mode '{schema.mode}' is not among supported modes: {model_profile.output_modes}" + ) + + return schema # pyright: ignore[reportReturnType] @staticmethod def is_model_request_node( diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 1961938f2..93a822c43 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -294,12 +294,10 @@ class ModelRequestParameters: function_tools: list[ToolDefinition] = field(default_factory=list) - output_mode: OutputMode | None = None + output_mode: OutputMode = 'text' output_object: OutputObjectDefinition | None = None output_tools: list[ToolDefinition] = field(default_factory=list) - require_tool_use: bool = ( - True # TODO: Rename back to allow_text_output because this is public API. Support bool as well as plain/json - ) + allow_text_output: bool = True class Model(ABC): @@ -430,18 +428,6 @@ def _get_instructions(messages: list[ModelMessage]) -> str | None: return None - @property - def supported_output_modes(self) -> set[OutputMode]: - """The supported output modes for the model.""" - # TODO: Move to ModelProfile - return {'tool'} # TODO: Support manual_json on all - - @property - def default_output_mode(self) -> OutputMode: - """The default output mode for the model.""" - # TODO: Move to ModelProfile - return 'tool' - @dataclass class StreamedResponse(ABC): diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index e75cb4a18..35b8e8a52 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -213,7 +213,7 @@ async def _messages_create( if not tools: tool_choice = None else: - if model_request_parameters.require_tool_use: + if model_request_parameters.output_mode == 'tool': tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 6e4e4797e..3ed16e726 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -378,7 +378,7 @@ def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> return None tool_choice: ToolChoiceTypeDef - if model_request_parameters.require_tool_use: + if model_request_parameters.output_mode == 'tool': tool_choice = {'any': {}} else: tool_choice = {'auto': {}} diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 934707cce..22bcddffb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -91,7 +91,7 @@ async def request( ) -> ModelResponse: agent_info = AgentInfo( model_request_parameters.function_tools, - not model_request_parameters.require_tool_use, + model_request_parameters.allow_text_output, model_request_parameters.output_tools, model_settings, ) @@ -120,7 +120,7 @@ async def request_stream( ) -> AsyncIterator[StreamedResponse]: agent_info = AgentInfo( model_request_parameters.function_tools, - not model_request_parameters.require_tool_use, + model_request_parameters.allow_text_output, model_request_parameters.output_tools, model_settings, ) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index aa8c23eed..08c83e539 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -192,7 +192,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: - if not model_request_parameters.require_tool_use: + if model_request_parameters.output_mode != 'tool': return None elif tools: return _tool_config([t['name'] for t in tools['function_declarations']]) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 38799712e..faeb2d13c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -213,7 +213,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: - if not model_request_parameters.require_tool_use: + if model_request_parameters.output_mode != 'tool': return None elif tools: names: list[str] = [] diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index c6dfed3bd..b9c8ef54c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -209,7 +209,7 @@ async def _completions_create( # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif model_request_parameters.require_tool_use: + elif model_request_parameters.output_mode == 'tool': tool_choice = 'required' else: tool_choice = 'auto' diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0ff134a91..866ccb58b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -287,7 +287,7 @@ def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> """ if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None - elif model_request_parameters.require_tool_use: + elif model_request_parameters.output_mode == 'tool': return 'required' else: return 'auto' diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 1f0675123..f5404579c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -14,7 +14,7 @@ from pydantic_ai.providers import Provider, infer_provider from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage -from .._output import OutputMode, OutputObjectDefinition +from .._output import OutputObjectDefinition from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( AudioUrl, @@ -238,11 +238,6 @@ def system(self) -> str: """The system / model provider.""" return self._system - @property - def supported_output_modes(self) -> set[OutputMode]: - """The supported output modes for the model.""" - return {'tool', 'json_schema', 'manual_json'} - @overload async def _completions_create( self, @@ -273,22 +268,17 @@ async def _completions_create( tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] response_format: chat.completion_create_params.ResponseFormat | None = None - if model_request_parameters.output_mode == 'tool': + output_mode = model_request_parameters.output_mode + if output_mode == 'tool': tools.extend(self._map_tool_definition(r) for r in model_request_parameters.output_tools) - elif output_object := model_request_parameters.output_object: - if model_request_parameters.output_mode == 'json_schema': - response_format = self._map_output_object_definition(output_object) - elif model_request_parameters.output_mode == 'manual_json': - openai_messages.insert( - 0, - chat.ChatCompletionSystemMessageParam( - content=output_object.manual_json_instructions, role='system' - ), - ) + elif output_mode == 'json_schema': + output_object = model_request_parameters.output_object + assert output_object is not None + response_format = self._map_output_object_definition(output_object) if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif model_request_parameters.require_tool_use: + elif model_request_parameters.output_mode == 'tool': tool_choice = 'required' else: tool_choice = 'auto' @@ -599,11 +589,6 @@ def system(self) -> str: """The system / model provider.""" return self._system - @property - def supported_output_modes(self) -> set[OutputMode]: - """The supported output modes for the model.""" - return {'tool', 'json_schema'} - async def request( self, messages: list[ModelRequest | ModelResponse], @@ -687,7 +672,7 @@ async def _responses_create( # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif model_request_parameters.require_tool_use: + elif model_request_parameters.output_mode == 'tool': tool_choice = 'required' else: tool_choice = 'auto' @@ -896,7 +881,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: - # TODO: Handle structured output yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) for dtc in choice.delta.tool_calls or []: @@ -978,7 +962,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: pass elif isinstance(chunk, responses.ResponseTextDeltaEvent): - # TODO: Handle structured output yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta) elif isinstance(chunk, responses.ResponseTextDoneEvent): diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 9079afa5e..0daad25bc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -130,7 +130,7 @@ def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> l def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput: if self.custom_output_text is not None: - assert not model_request_parameters.require_tool_use, ( + assert model_request_parameters.allow_text_output, ( 'Plain response not allowed, but `custom_output_text` is set.' ) assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.' @@ -145,7 +145,7 @@ def _get_output(self, model_request_parameters: ModelRequestParameters) -> _Wrap return _WrappedToolOutput({k: self.custom_output_args}) else: return _WrappedToolOutput(self.custom_output_args) - elif not model_request_parameters.require_tool_use: + elif model_request_parameters.allow_text_output: return _WrappedTextOutput(None) elif model_request_parameters.output_tools: return _WrappedToolOutput(None) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 3792c95a6..743442c66 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations -from dataclasses import dataclass, fields, replace -from typing import Callable, Union +from dataclasses import dataclass, field, fields, replace +from typing import Callable, Literal, Union from typing_extensions import Self @@ -13,6 +13,9 @@ class ModelProfile: """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used.""" json_schema_transformer: type[JsonSchemaTransformer] | None = None + output_modes: set[Literal['tool', 'json_schema']] = field(default_factory=lambda: {'tool'}) + # TODO: Add docstrings + default_output_mode: Literal['tool', 'json_schema', 'manual_json'] = 'tool' @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index 5bdac9f6d..a708000fe 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -21,7 +21,7 @@ class OpenAIModelProfile(ModelProfile): def openai_model_profile(model_name: str) -> ModelProfile: """Get the model profile for an OpenAI model.""" - return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer) + return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer, output_modes={'tool', 'json_schema'}) _STRICT_INCOMPATIBLE_KEYS = [ diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 6c1007266..de535aba0 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,11 +5,11 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import Generic, cast +from typing import Generic from typing_extensions import TypeVar, assert_type, deprecated, overload -from . import _output, _utils, exceptions, messages as _messages, models +from . import _utils, exceptions, messages as _messages, models from ._output import ( JSONSchemaOutput, ManualJSONOutput, @@ -34,7 +34,7 @@ @dataclass class AgentStream(Generic[AgentDepsT, OutputDataT]): _raw_stream_response: models.StreamedResponse - _output_schema: OutputSchema[OutputDataT] | None + _output_schema: OutputSchema[OutputDataT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None @@ -82,7 +82,7 @@ async def _validate_response( ) -> OutputDataT: """Validate a structured result message.""" call = None - if self._output_schema is not None and output_tool_name is not None: + if output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -96,13 +96,9 @@ async def _validate_response( else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - if self._output_schema is None or self._output_schema.allow_text_output == 'plain': - # The following cast is safe because we know `str` is an allowed output type - result_data = cast(OutputDataT, text) - else: - result_data = await self._output_schema.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False - ) + result_data = await self._output_schema.process( + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) @@ -126,12 +122,9 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. if isinstance(e, _messages.PartStartEvent): new_part = e.part if isinstance(new_part, _messages.ToolCallPart): - if output_schema: - for call, _ in output_schema.find_tool([new_part]): # pragma: no branch - return _messages.FinalResultEvent( - tool_name=call.tool_name, tool_call_id=call.tool_call_id - ) - elif _output.allow_text_output(output_schema): # pragma: no branch + for call, _ in output_schema.find_tool([new_part]): # pragma: no branch + return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id) + elif output_schema.allow_text_output: # pragma: no branch assert_type(e, _messages.PartStartEvent) return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) @@ -163,7 +156,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _usage_limits: UsageLimits | None _stream_response: models.StreamedResponse - _output_schema: OutputSchema[OutputDataT] | None + _output_schema: OutputSchema[OutputDataT] _run_ctx: RunContext[AgentDepsT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None @@ -319,7 +312,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ - if self._output_schema and self._output_schema.allow_text_output != 'plain': + if self._output_schema.allow_text_output != 'plain': raise exceptions.UserError('stream_text() can only be used with text responses') if delta: @@ -398,7 +391,7 @@ async def validate_structured_output( ) -> OutputDataT: """Validate a structured result message.""" call = None - if self._output_schema is not None and self._output_tool_name is not None: + if self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -412,13 +405,9 @@ async def validate_structured_output( else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - if self._output_schema is None or self._output_schema.allow_text_output == 'plain': - # The following cast is safe because we know `str` is an allowed output type - result_data = cast(OutputDataT, text) - else: - result_data = await self._output_schema.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False - ) + result_data = await self._output_schema.process( + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 35578309c..78f8f4e85 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -684,7 +684,7 @@ async def test_bedrock_anthropic_no_tool_choice(bedrock_provider: BedrockProvide 'This is my tool', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}, ) - mrp = ModelRequestParameters(function_tools=[my_tool], require_tool_use=True, output_tools=[]) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=False, output_tools=[]) # Models other than Anthropic support tool_choice model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index ad6c1bab1..c88876832 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -127,7 +127,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:success_response:', 'gen_ai.system': 'function', @@ -200,7 +200,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function::failure_response_stream,function::success_response_stream', 'gen_ai.system': 'function', @@ -272,7 +272,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'fallback:function,function', 'gen_ai.request.model': 'fallback:function:failure_response:,function:failure_response:', - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:failure_response:', diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ca26dd408..7a022250b 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -64,7 +64,7 @@ async def test_model_simple(allow_model_requests: None): assert 'x-goog-api-key' in m.client.headers mrp = ModelRequestParameters( - function_tools=[], require_tool_use=False, output_tools=[], output_mode=None, output_object=None + function_tools=[], allow_text_output=True, output_tools=[], output_mode='text', output_object=None ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) @@ -100,9 +100,9 @@ async def test_model_tools(allow_model_requests: None): mrp = ModelRequestParameters( function_tools=tools, - require_tool_use=False, + allow_text_output=True, output_tools=[output_tool], - output_mode=None, + output_mode='text', output_object=None, ) mrp = m.customize_request_parameters(mrp) @@ -148,9 +148,9 @@ async def test_require_response_tool(allow_model_requests: None): ) mrp = ModelRequestParameters( function_tools=[], - require_tool_use=True, + allow_text_output=False, output_tools=[output_tool], - output_mode=None, + output_mode='tool', output_object=None, ) mrp = m.customize_request_parameters(mrp) @@ -235,9 +235,9 @@ class Locations(BaseModel): ) mrp = ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[output_tool], - output_mode=None, + output_mode='text', output_object=None, ) mrp = m.customize_request_parameters(mrp) @@ -320,8 +320,8 @@ class QueryDetails(BaseModel): ) mrp = ModelRequestParameters( function_tools=[], - output_mode=None, - require_tool_use=False, + output_mode='text', + allow_text_output=True, output_tools=[output_tool], output_object=None, ) @@ -368,9 +368,9 @@ class Locations(BaseModel): ) mrp = ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[output_tool], - output_mode=None, + output_mode='text', output_object=None, ) mrp = m.customize_request_parameters(mrp) @@ -438,9 +438,9 @@ class Location(BaseModel): with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'): mrp = ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[output_tool], - output_mode=None, + output_mode='text', output_object=None, ) mrp = m.customize_request_parameters(mrp) @@ -476,9 +476,9 @@ class FormattedStringFields(BaseModel): ) mrp = ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[output_tool], - output_mode=None, + output_mode='text', output_object=None, ) mrp = m.customize_request_parameters(mrp) diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 5bb8204db..1e17c5469 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -132,9 +132,9 @@ async def test_instrumented_model(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[], - output_mode=None, + output_mode='text', output_object=None, ), ) @@ -153,7 +153,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -330,9 +330,9 @@ async def test_instrumented_model_not_recording(): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[], - output_mode=None, + output_mode='text', output_object=None, ), ) @@ -354,9 +354,9 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[], - output_mode=None, + output_mode='text', output_object=None, ), ) as response_stream: @@ -381,7 +381,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -444,9 +444,9 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[], - output_mode=None, + output_mode='text', output_object=None, ), ) as response_stream: @@ -468,7 +468,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -549,9 +549,9 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): model_settings=ModelSettings(temperature=1), model_request_parameters=ModelRequestParameters( function_tools=[], - require_tool_use=False, + allow_text_output=True, output_tools=[], - output_mode=None, + output_mode='text', output_object=None, ), ) @@ -570,7 +570,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', 'logfire.span_type': 'span', diff --git a/tests/models/test_model_request_parameters.py b/tests/models/test_model_request_parameters.py index 5a0918211..98a6d1ccc 100644 --- a/tests/models/test_model_request_parameters.py +++ b/tests/models/test_model_request_parameters.py @@ -5,12 +5,12 @@ def test_model_request_parameters_are_serializable(): params = ModelRequestParameters( - function_tools=[], output_mode=None, require_tool_use=False, output_tools=[], output_object=None + function_tools=[], output_mode='text', allow_text_output=True, output_tools=[], output_object=None ) assert TypeAdapter(ModelRequestParameters).dump_python(params) == { 'function_tools': [], - 'preferred_output_mode': None, - 'require_tool_use': False, + 'output_mode': 'text', + 'allow_text_output': True, 'output_tools': [], 'output_object': None, } diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index d8702e384..f606b4838 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -2006,7 +2006,16 @@ async def get_user_country() -> str: content='What is the largest city in the user country?', timestamp=IsDatetime(), ) - ] + ], + instructions="""\ +Always respond with a JSON object matching this description and schema: + +CityLocation + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", ), ModelResponse( parts=[ @@ -2037,7 +2046,16 @@ async def get_user_country() -> str: tool_call_id='call_uTjt2vMkeTr0GYqQyQYrUUhl', timestamp=IsDatetime(), ) - ] + ], + instructions="""\ +Always respond with a JSON object matching this description and schema: + +CityLocation + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", ), ModelResponse( parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], diff --git a/tests/test_agent.py b/tests/test_agent.py index 6be9f8e30..848878b55 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -354,14 +354,14 @@ def test_response_tuple(): m = TestModel() agent = Agent(m, output_type=tuple[str, str]) - assert agent._output_schema.allow_text_output is None # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert agent._output_schema.allow_text_output is False # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') assert result.output == snapshot(('a', 'a')) assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.require_tool_use is True + assert m.last_model_request_parameters.allow_text_output is False assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 1 @@ -411,7 +411,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: got_tool_call_name = ctx.tool_name return o - assert agent._output_schema.allow_text_output == 'plain' # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert agent._output_schema.allow_text_output == 'plain' # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') assert result.output == snapshot('success (no tool calls)') @@ -419,7 +419,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.require_tool_use is False + assert m.last_model_request_parameters.allow_text_output is True assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 1 @@ -497,7 +497,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert m.last_model_request_parameters is not None assert m.last_model_request_parameters.function_tools == snapshot([]) - assert m.last_model_request_parameters.require_tool_use is True + assert m.last_model_request_parameters.allow_text_output is False assert m.last_model_request_parameters.output_tools is not None assert len(m.last_model_request_parameters.output_tools) == 2 diff --git a/tests/test_direct.py b/tests/test_direct.py index 16c2b0bd2..46f409ef0 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -51,7 +51,7 @@ async def test_model_request_tool_call(): function_tools=[ ToolDefinition(name='tool_name', description='', parameters_json_schema={'type': 'object'}) ], - require_tool_use=True, + allow_text_output=False, ), ) assert model_response == snapshot( diff --git a/tests/test_logfire.py b/tests/test_logfire.py index eff2c62ba..92a99c48c 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -223,10 +223,10 @@ async def my_ret(x: int) -> str: 'strict': None, } ], - 'output_mode': None, + 'output_mode': 'text', 'output_tools': [], 'output_object': None, - 'require_tool_use': False, + 'allow_text_output': True, } ) ), @@ -406,7 +406,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'test', 'gen_ai.request.model': 'test', - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "require_tool_use": false}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.span_type': 'span', 'logfire.msg': 'chat test', 'gen_ai.usage.input_tokens': 51, diff --git a/tests/test_tools.py b/tests/test_tools.py index 218c564e4..d92a7af3c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -571,7 +571,7 @@ def test_tool_return_conflict(): # this is also okay Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) # this raises an error - with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"): + with pytest.raises(UserError, match="Tool name conflicts with output tool name: 'ctx_tool'"): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) From 0cb25c478855a69b7868b7002ff4c2ed27504c93 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 3 Jun 2025 05:29:11 +0000 Subject: [PATCH 03/90] Fix tests --- pydantic_ai_slim/pydantic_ai/_output.py | 12 +- ...st_openai_manual_json_output_multiple.yaml | 209 ++++++++++++++++++ tests/models/test_openai.py | 110 +++++++++ tests/test_agent.py | 1 + 4 files changed, 327 insertions(+), 5 deletions(-) create mode 100644 tests/models/cassettes/test_openai/test_openai_manual_json_output_multiple.yaml diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 83a623c93..ed97d38a4 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -210,7 +210,7 @@ def __init__( ) # TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose automatically, relying on Pydantic validation -type OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'manual_json'] +OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'manual_json'] @dataclass @@ -370,7 +370,7 @@ def tool_defs(self) -> list[ToolDefinition]: async def process( self, - data: str | dict[str, Any], + text: str, run_context: RunContext[AgentDepsT], allow_partial: bool = False, wrap_validation_errors: bool = True, @@ -378,7 +378,7 @@ async def process( """Validate an output message. Args: - data: The output data to validate. + text: The output text to validate. run_context: The current run context. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -389,12 +389,14 @@ async def process( assert self.allow_text_output is not False if self.allow_text_output == 'plain': - return cast(OutputDataT, data) + return cast(OutputDataT, text) + # TODO: Always give this some value so we can drop some checks/asserts assert self.object_schema is not None + # TODO: Strip Markdown fences? return await self.object_schema.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) diff --git a/tests/models/cassettes/test_openai/test_openai_manual_json_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_manual_json_output_multiple.yaml new file mode 100644 index 000000000..1c143bc6c --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_manual_json_output_multiple.yaml @@ -0,0 +1,209 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1438' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object matching this description and schema: + + final_result: The final response which ends this conversation + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '549' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_W3kwVF2ZX9cZ2L9NnbSDSs3V + type: function + created: 1748928408 + id: chatcmpl-BeESepFYXE1ELAEvKlRNvsRzzYkOg + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 284 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 295 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1669' + content-type: + - application/json + cookie: + - __cf_bm=bTNTDMnAXJaO9CpT6TAFSQOHtSGvSowNKQi20qZpxKU-1748928408-1.0.1.1-RtnC2hZt1TL38SpAb3dXL5bL7Q7EuNYgc0.18VudHVT7WzCWgkYNSYscp6aLzd9yCQaDu__K1Q2tP05IZqU3y2KwNh0JTcuf.o8GEEG_kbY; + _cfuvid=6ycbGPnBJWGLcM9mdMSLys2eUHapvWyYqbWltcaQhkQ-1748928408894-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object matching this description and schema: + + final_result: The final response which ends this conversation + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_W3kwVF2ZX9cZ2L9NnbSDSs3V + type: function + - content: Mexico + role: tool + tool_call_id: call_W3kwVF2ZX9cZ2L9NnbSDSs3V + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '903' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '900' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + refusal: null + role: assistant + created: 1748928409 + id: chatcmpl-BeESfsdYDCQwP7kGM4r5i8KXwllkT + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 21 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 305 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 326 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index c41c9a4ec..2a99e6e6b 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -2100,3 +2100,113 @@ async def get_user_country() -> str: ), ] ) + + +@pytest.mark.vcr() +async def test_openai_manual_json_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + # TODO: Test with functions! + agent = Agent(m, output_type=ManualJSONOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + # TODO: Show what response_format looks like + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object matching this description and schema: + +final_result: The final response which ends this conversation + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_W3kwVF2ZX9cZ2L9NnbSDSs3V') + ], + usage=Usage( + requests=1, + request_tokens=284, + response_tokens=11, + total_tokens=295, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BeESepFYXE1ELAEvKlRNvsRzzYkOg', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_W3kwVF2ZX9cZ2L9NnbSDSs3V', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object matching this description and schema: + +final_result: The final response which ends this conversation + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=305, + response_tokens=21, + total_tokens=326, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BeESfsdYDCQwP7kGM4r5i8KXwllkT', + ), + ] + ) diff --git a/tests/test_agent.py b/tests/test_agent.py index 848878b55..ef004f159 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2487,6 +2487,7 @@ def instructions() -> str: parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())], instructions="""\ You are a helpful assistant. + You are a potato.\ """, ) From 933b74e764ce0f23890b6fc4152e86e19473e621 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 3 Jun 2025 05:32:01 +0000 Subject: [PATCH 04/90] Remove syntax invalid before Python 3.12 --- pydantic_ai_slim/pydantic_ai/_output.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ed97d38a4..9d4f30c70 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -693,9 +693,7 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return () -def flatten_output_types[T]( - output_type: T | Sequence[T], -) -> list[T]: +def flatten_output_types(output_type: T | Sequence[T]) -> list[T]: output_types: Sequence[T] if isinstance(output_type, Sequence): output_types = output_type From 7974df065bc211d32233e95156168f2fc0b24137 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 3 Jun 2025 05:36:56 +0000 Subject: [PATCH 05/90] Fix tests --- pydantic_ai_slim/pydantic_ai/agent.py | 2 ++ tests/models/test_instrumented.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 25f1d80f7..a9a843a92 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1651,6 +1651,8 @@ def _prepare_output_schema( schema = self._output_schema if schema.mode is None: + # TODO: This may need to be done later, when we know if there are any model_request_parameters.function_tools, + # as some models do not support tool calls at the same time as json_schema output, and which mode we pick may be different... schema.mode = model_profile.default_output_mode if not schema.is_mode_supported(model_profile): raise exceptions.UserError( diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 1e17c5469..ac5383663 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -153,7 +153,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -381,7 +381,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -468,7 +468,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -570,7 +570,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "output_mode": null, "output_object": null, "output_tools": [], "allow_text_output": true}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', 'logfire.span_type': 'span', From 9cc19e22e0300650c1307684088c926ffe5250ca Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 9 Jun 2025 21:08:22 +0000 Subject: [PATCH 06/90] Add TextOutput marker --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/_output.py | 317 ++++++++++++------- pydantic_ai_slim/pydantic_ai/agent.py | 20 +- tests/models/test_openai.py | 81 ++++- tests/test_agent.py | 115 ++++++- tests/typed_agent.py | 32 +- 6 files changed, 430 insertions(+), 137 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e6fdc7224..0a7236a43 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -269,7 +269,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None: return models.ModelRequestParameters( function_tools=function_tool_defs, output_mode=output_schema.mode, - output_object=output_schema.object_schema.definition if output_schema.object_schema else None, + output_object=output_schema.text_output_schema.object_def if output_schema.text_output_schema else None, output_tools=output_schema.tool_defs(), allow_text_output=output_schema.allow_text_output == 'plain', ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 9d4f30c70..a72403218 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -16,7 +16,7 @@ from pydantic_ai.profiles import ModelProfile from . import _function_schema, _utils, messages as _messages -from .exceptions import ModelRetry +from .exceptions import ModelRetry, UserError from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition T = TypeVar('T') @@ -150,6 +150,16 @@ def __init__( self.strict = strict +@dataclass +class TextOutput(Generic[OutputDataT]): + """Marker class to use text output for outputs.""" + + output_type: ( + Callable[[RunContext, str], Awaitable[OutputDataT] | OutputDataT] + | Callable[[str], Awaitable[OutputDataT] | OutputDataT] + ) + + @dataclass(init=False) class JSONSchemaOutput(Generic[OutputDataT]): """Marker class to use JSON schema output for outputs.""" @@ -195,14 +205,15 @@ def __init__( T_co = TypeVar('T_co', covariant=True) OutputTypeOrFunction = TypeAliasType( - 'OutputTypeOrFunction', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,) + 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Awaitable[T_co] | T_co]], type_params=(T_co,) ) OutputType = TypeAliasType( 'OutputType', Union[ OutputTypeOrFunction[T_co], ToolOutput[T_co], - Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co]]], + TextOutput[T_co], + Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], JSONSchemaOutput[T_co], ManualJSONOutput[T_co], ], @@ -213,124 +224,160 @@ def __init__( OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'manual_json'] -@dataclass +@dataclass(init=False) class OutputSchema(Generic[OutputDataT]): """Model the final output from an agent run. Similar to `Tool` but for the final output of running an agent. """ - mode: OutputMode | None - object_schema: OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | None = None + mode: OutputMode | None = None + text_output_schema: ( + OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | OutputTextSchema[OutputDataT] | None + ) = None tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) - @classmethod - def build( - cls: type[OutputSchema[OutputDataT]], + def __init__( + self, output_type: OutputType[OutputDataT], - name: str | None, - description: str | None, - ) -> OutputSchema[OutputDataT]: + *, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ): """Build an OutputSchema dataclass from an output type.""" - if output_type is str: - return cls(mode='text') + self.mode = None + self.text_output_schema = None + self.tools = {} - mode: OutputMode | None = None - tools: dict[str, OutputTool[OutputDataT]] = {} - strict: bool | None = None + if output_type is str: + self.mode = 'text' + self.text_output_schema = OutputTextSchema(output_type) + return - output_types: Sequence[OutputTypeOrFunction[OutputDataT]] if isinstance(output_type, JSONSchemaOutput): - mode = 'json_schema' - output_types = output_type.output_types - name = output_type.name # TODO: If not set, use method arg? - description = output_type.description - strict = output_type.strict - elif isinstance(output_type, ManualJSONOutput): - mode = 'manual_json' - output_types = output_type.output_types - name = output_type.name - description = output_type.description - else: - output_types_or_tool_outputs = flatten_output_types(output_type) - - if str in output_types_or_tool_outputs: - if len(output_types_or_tool_outputs) == 1: - return cls(mode='text') - else: - mode = 'tool_or_text' - output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str] - - multiple = len(output_types_or_tool_outputs) > 1 - - default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME - default_tool_description = description - default_tool_strict = strict - - output_types = [] - for output_type_or_tool_output in output_types_or_tool_outputs: - tool_name = None - tool_description = None - tool_strict = None - if isinstance(output_type_or_tool_output, ToolOutput): - if mode is None: - mode = 'tool' - - tool_output = output_type_or_tool_output - output_type = tool_output.output_type - # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads - tool_name = tool_output.name - tool_description = tool_output.description - tool_strict = tool_output.strict - else: - output_type = output_type_or_tool_output + self.mode = 'json_schema' + self.text_output_schema = self._build_text_output_schema( + output_type.output_types, + name=output_type.name, + description=output_type.description, + strict=output_type.strict, + ) + return - if tool_name is None: - tool_name = default_tool_name - if multiple: - tool_name += f'_{output_type.__name__}' + if isinstance(output_type, ManualJSONOutput): + self.mode = 'manual_json' + self.text_output_schema = self._build_text_output_schema( + output_type.output_types, name=output_type.name, description=output_type.description + ) + return + + text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] + tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] + other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] + for output_type_or_marker in flatten_output_types(output_type): + if output_type_or_marker is str: + text_outputs.append(cast(type[str], output_type_or_marker)) + elif isinstance(output_type_or_marker, TextOutput): + text_outputs.append(output_type_or_marker) + elif isinstance(output_type_or_marker, ToolOutput): + tool_outputs.append(output_type_or_marker) + else: + other_outputs.append(output_type_or_marker) - i = 1 - original_tool_name = tool_name - while tool_name in tools: - i += 1 - tool_name = f'{original_tool_name}_{i}' + self.tools = self._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) - tool_description = tool_description or default_tool_description - if tool_strict is None: - tool_strict = default_tool_strict + if len(text_outputs) > 0: + if len(text_outputs) > 1: + raise UserError('Only one text output is allowed') + text_output = text_outputs[0] - parameters_schema = OutputObjectSchema( - output_type=output_type, description=tool_description, strict=tool_strict - ) - tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple) - output_types.append(output_type) + self.mode = 'text' + if len(self.tools) > 0: + self.mode = 'tool_or_text' - output_types = flatten_output_types(output_types) - if len(output_types) > 1: - output_object_schema = OutputUnionSchema( - output_types=output_types, name=name, description=description, strict=strict - ) + if isinstance(text_output, TextOutput): + self.text_output_schema = OutputTextSchema(text_output.output_type) + elif text_output is str: + self.text_output_schema = cast(OutputTextSchema[OutputDataT], OutputTextSchema(text_output)) + elif len(tool_outputs) > 0: + self.mode = 'tool' else: - output_object_schema = OutputObjectSchema( - output_type=output_types[0], name=name, description=description, strict=strict + self.text_output_schema = self._build_text_output_schema( + other_outputs, name=name, description=description, strict=strict ) - return cls( - mode=mode, - object_schema=output_object_schema, - tools=tools, - ) + @staticmethod + def _build_tools( + outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> dict[str, OutputTool[OutputDataT]]: + tools: dict[str, OutputTool[OutputDataT]] = {} + + default_name = name or DEFAULT_OUTPUT_TOOL_NAME + default_description = description + default_strict = strict + + multiple = len(outputs) > 1 + for output in outputs: + name = None + description = None + strict = None + if isinstance(output, ToolOutput): + output_type = output.output_type + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads + name = output.name + description = output.description + strict = output.strict + else: + output_type = output + + if name is None: + name = default_name + if multiple: + name += f'_{output_type.__name__}' + + i = 1 + original_name = name + while name in tools: + i += 1 + name = f'{original_name}_{i}' + + description = description or default_description + if strict is None: + strict = default_strict + + parameters_schema = OutputObjectSchema(output_type=output_type, description=description, strict=strict) + tools[name] = OutputTool(name=name, parameters_schema=parameters_schema, multiple=multiple) + + return tools + + @staticmethod + def _build_text_output_schema( + outputs: Sequence[OutputTypeOrFunction[OutputDataT]], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | None: + if len(outputs) == 0: + return None + + outputs = flatten_output_types(outputs) + if len(outputs) == 1: + return OutputObjectSchema(output_type=outputs[0], name=name, description=description, strict=strict) + + return OutputUnionSchema(output_types=outputs, name=name, description=description, strict=strict) @property def allow_text_output(self) -> Literal['plain', 'json', False]: """Whether the model allows text output.""" + if self.mode == 'tool': + return False if self.mode in ('text', 'tool_or_text'): return 'plain' - elif self.mode in ('json_schema', 'manual_json'): - return 'json' - else: # tool-only mode - return False + return 'json' def is_mode_supported(self, profile: ModelProfile) -> bool: """Whether the model supports the output mode.""" @@ -387,15 +434,10 @@ async def process( Either the validated output data (left) or a retry message (right). """ assert self.allow_text_output is not False - - if self.allow_text_output == 'plain': - return cast(OutputDataT, text) - - # TODO: Always give this some value so we can drop some checks/asserts - assert self.object_schema is not None + assert self.text_output_schema is not None # TODO: Strip Markdown fences? - return await self.object_schema.process( + return await self.text_output_schema.process( text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -429,8 +471,7 @@ class OutputUnionData: # TODO: Better class naming @dataclass(init=False) class OutputUnionSchema(Generic[OutputDataT]): - definition: OutputObjectDefinition - outer_typed_dict_key: str = 'result' + object_def: OutputObjectDefinition _root_object_schema: OutputObjectSchema[OutputUnionData] _object_schemas: dict[str, OutputObjectSchema[OutputDataT]] @@ -461,9 +502,9 @@ def __init__( 'kind': { 'const': name, }, - 'data': object_schema.definition.json_schema, # TODO: Pop description here? + 'data': object_schema.object_def.json_schema, # TODO: Pop description here? }, - 'description': object_schema.definition.description or name, # TODO: Better description + 'description': object_schema.object_def.description or name, # TODO: Better description 'required': ['kind', 'data'], 'additionalProperties': False, } @@ -475,7 +516,7 @@ def __init__( 'additionalProperties': False, } - self.definition = OutputObjectDefinition( + self.object_def = OutputObjectDefinition( name=name or DEFAULT_OUTPUT_TOOL_NAME, description=description or DEFAULT_OUTPUT_TOOL_DESCRIPTION, json_schema=json_schema, @@ -509,15 +550,15 @@ async def process( @dataclass(init=False) class OutputObjectSchema(Generic[OutputDataT]): - definition: OutputObjectDefinition + object_def: OutputObjectDefinition outer_typed_dict_key: str | None = None _validator: SchemaValidator _function_schema: _function_schema.FunctionSchema | None = None def __init__( self, - *, output_type: OutputTypeOrFunction[OutputDataT], + *, name: str | None = None, description: str | None = None, strict: bool | None = None, @@ -555,7 +596,7 @@ def __init__( else: description = f'{description}. {json_schema_description}' - self.definition = OutputObjectDefinition( + self.object_def = OutputObjectDefinition( name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME), description=description, json_schema=json_schema, @@ -595,6 +636,9 @@ async def process( else: raise + if k := self.outer_typed_dict_key: + output = output[k] + if self._function_schema: try: output = await self._function_schema.call(output, run_context) @@ -607,11 +651,60 @@ async def process( else: raise - if k := self.outer_typed_dict_key: - output = output[k] return output +@dataclass(init=False) +class OutputTextSchema(Generic[OutputDataT]): + _function_schema: _function_schema.FunctionSchema | None = None + _str_argument_name: str | None = None + + def __init__( + self, + output_type: type[OutputDataT] + | Callable[[RunContext[AgentDepsT], str], Awaitable[OutputDataT] | OutputDataT] + | Callable[[str], Awaitable[OutputDataT] | OutputDataT] = str, + ): + if inspect.isfunction(output_type) or inspect.ismethod(output_type): + self._function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) + arguments_schema = self._function_schema.json_schema.get('properties', {}) + argument_name = next(iter(arguments_schema.keys()), None) + if argument_name and arguments_schema.get(argument_name, {}).get('type') == 'string': + self._str_argument_name = argument_name + return + elif output_type is str: + return + + raise ValueError('OutputTextSchema must take the `str` type or a function taking a `str`') + + @property + def object_def(self) -> None: + return None + + async def process( + self, + data: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + output = data + + if self._function_schema and self._str_argument_name: + try: + output = await self._function_schema.call({self._str_argument_name: output}, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise + + return cast(OutputDataT, output) + + @dataclass(init=False) class OutputTool(Generic[OutputDataT]): parameters_schema: OutputObjectSchema[OutputDataT] @@ -619,7 +712,7 @@ class OutputTool(Generic[OutputDataT]): def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool): self.parameters_schema = parameters_schema - definition = parameters_schema.definition + definition = parameters_schema.object_def description = definition.description if not description: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index a9a843a92..7f758ec5d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -318,10 +318,10 @@ def __init__( warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning) output_retries = result_retries - self._output_schema = _output.OutputSchema[OutputDataT].build( + self._output_schema = _output.OutputSchema[OutputDataT]( output_type, - self._deprecated_result_tool_name, - self._deprecated_result_tool_description, + name=self._deprecated_result_tool_name, + description=self._deprecated_result_tool_description, ) self._output_validators = [] @@ -680,8 +680,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: *[await func.run(run_context) for func in self._instructions_functions], ] - if output_schema.mode == 'manual_json' and (output_object_schema := output_schema.object_schema): - parts.append(output_object_schema.definition.manual_json_instructions) + if ( + output_schema.mode == 'manual_json' + and (output_object_schema := output_schema.text_output_schema) + and (object_def := output_object_schema.object_def) + ): + parts.append(object_def.manual_json_instructions) parts = [p for p in parts if p] if not parts: @@ -1642,10 +1646,10 @@ def _prepare_output_schema( if output_type is not None: if self._output_validators: raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') - schema = _output.OutputSchema[RunOutputDataT].build( + schema = _output.OutputSchema[RunOutputDataT]( output_type, - self._deprecated_result_tool_name, - self._deprecated_result_tool_description, + name=self._deprecated_result_tool_name, + description=self._deprecated_result_tool_description, ) else: schema = self._output_schema diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 2a99e6e6b..5f4e7ece1 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -15,7 +15,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior -from pydantic_ai._output import ManualJSONOutput +from pydantic_ai._output import ManualJSONOutput, TextOutput from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -1830,6 +1830,85 @@ async def get_user_country() -> str: ) +async def test_openai_text_output_function(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot('THE LARGEST CITY IN MEXICO IS MEXICO CITY.') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_bZpN3tcL7reJvaSfcJhWIUaj') + ], + usage=Usage( + requests=1, + request_tokens=42, + response_tokens=11, + total_tokens=53, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgdk2GPEsxXyA9st3DaRFB6bRNXQa', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_bZpN3tcL7reJvaSfcJhWIUaj', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='The largest city in Mexico is Mexico City.')], + usage=Usage( + requests=1, + request_tokens=63, + response_tokens=10, + total_tokens=73, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgdk32OMvBlPIWjnfe4O1S4fVhYiv', + ), + ] + ) + + @pytest.mark.vcr() async def test_openai_json_schema_output(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) diff --git a/tests/test_agent.py b/tests/test_agent.py index ef004f159..12b4fb700 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -13,7 +13,7 @@ from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import ToolOutput +from pydantic_ai._output import TextOutput, ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -354,7 +354,7 @@ def test_response_tuple(): m = TestModel() agent = Agent(m, output_type=tuple[str, str]) - assert agent._output_schema.allow_text_output is False # pyright: ignore[reportPrivateUsage] + assert agent._output_schema.allow_text_output == 'json' # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') assert result.output == snapshot(('a', 'a')) @@ -389,10 +389,28 @@ def test_response_tuple(): ) +def upcase(text: str) -> str: + return text.upper() + + @pytest.mark.parametrize( 'input_union_callable', - [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str, lambda: [Foo, str]], - ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str', '[Foo, str]'], + [ + lambda: Union[str, Foo], + lambda: Union[Foo, str], + lambda: str | Foo, + lambda: Foo | str, + lambda: [Foo, str], + lambda: [TextOutput(upcase), ToolOutput(Foo)], + ], + ids=[ + 'Union[str, Foo]', + 'Union[Foo, str]', + 'str | Foo', + 'Foo | str', + '[Foo, str]', + '[TextOutput(upcase), ToolOutput(Foo)]', + ], ) def test_response_union_allow_str(input_union_callable: Callable[[], Any]): try: @@ -414,7 +432,8 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert agent._output_schema.allow_text_output == 'plain' # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') - assert result.output == snapshot('success (no tool calls)') + assert isinstance(result.output, str) + assert result.output.lower() == snapshot('success (no tool calls)') assert got_tool_call_name == snapshot(None) assert m.last_model_request_parameters is not None @@ -449,6 +468,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: [ pytest.param('OutputType = Union[Foo, Bar]'), pytest.param('OutputType = [Foo, Bar]'), + pytest.param('OutputType = [ToolOutput(Foo), ToolOutput(Bar)]'), pytest.param('OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')), pytest.param( 'OutputType: TypeAlias = Foo | Bar', @@ -847,6 +867,64 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_text_output_function_with_retry(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + if city != 'Mexico City': + raise ModelRetry('City not found, I only know Mexico City') + return Weather(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + city = 'New York City' + else: + city = 'Mexico City' + + return ModelResponse(parts=[TextPart(content=city)]) + + agent = Agent(FunctionModel(call_tool), output_type=TextOutput(get_weather)) + result = agent.run_sync('New York City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='New York City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='New York City')], + usage=Usage(requests=1, request_tokens=53, response_tokens=3, total_tokens=56), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='City not found, I only know Mexico City', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Mexico City')], + usage=Usage(requests=1, request_tokens=68, response_tokens=5, total_tokens=73), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ] + ) + + def test_output_type_async_function(): class Weather(BaseModel): temperature: float @@ -971,6 +1049,33 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_text_output_function(): + def say_world(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='world')]) + + agent = Agent(FunctionModel(say_world), output_type=TextOutput(upcase)) + result = agent.run_sync('hello') + assert result.output == snapshot('WORLD') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + model_name='function:say_world:', + timestamp=IsDatetime(), + ), + ] + ) + + def test_output_type_handoff_to_agent(): class Weather(BaseModel): temperature: float diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 180ce2b0d..eaa2c4fa8 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -1,13 +1,15 @@ """This file is used to test static typing, it's analyzed with pyright and mypy.""" +import re from collections.abc import Awaitable from dataclasses import dataclass +from decimal import Decimal from typing import Callable, TypeAlias, Union from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool -from pydantic_ai._output import ToolOutput +from pydantic_ai._output import TextOutput, ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.tools import ToolDefinition @@ -169,21 +171,25 @@ def run_sync3() -> None: assert_type(union_agent2, Agent[None, MyUnion]) -def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str: - return f'{x} {y}' +def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> Decimal: + return Decimal(x) + y async def foobar_plain(x: int, y: int) -> int: return x * y +def str_to_regex(text: str) -> re.Pattern[str]: + return re.compile(text) + + class MyClass: def my_method(self) -> bool: return True -str_function_agent = Agent(output_type=foobar_ctx) -assert_type(str_function_agent, Agent[None, str]) +decimal_function_agent = Agent(output_type=foobar_ctx) +assert_type(decimal_function_agent, Agent[None, Decimal]) bool_method_agent = Agent(output_type=MyClass().my_method) assert_type(bool_method_agent, Agent[None, bool]) @@ -200,10 +206,12 @@ def my_method(self) -> bool: assert_type(two_scalars_output_agent, Agent[None, int | str]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore - complex_output_agent = Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]( - output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker] + complex_output_agent = Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]]( + output_type=[str, Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker, TextOutput(str_to_regex)] + ) + assert_type( + complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) - assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) else: # pyright is able to correctly infer the type here async_int_function_agent = Agent(output_type=foobar_plain) @@ -216,8 +224,12 @@ def my_method(self) -> bool: assert_type(two_scalars_output_agent, Agent[None, int | str]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore - complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker]) - assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) + complex_output_agent = Agent( + output_type=[str, Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker, TextOutput(str_to_regex)] + ) + assert_type( + complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] + ) Tool(foobar_ctx, takes_ctx=True) From 0e356a392679e987c751ad31c768d00125645df6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 9 Jun 2025 21:22:09 +0000 Subject: [PATCH 07/90] Add VCR recording of new test --- pydantic_ai_slim/pydantic_ai/__init__.py | 6 +- pydantic_ai_slim/pydantic_ai/_output.py | 12 +- pydantic_ai_slim/pydantic_ai/result.py | 6 +- .../test_openai_text_output_function.yaml | 191 ++++++++++++++++++ tests/models/test_openai.py | 21 +- 5 files changed, 214 insertions(+), 22 deletions(-) create mode 100644 tests/models/cassettes/test_openai/test_openai_text_output_function.yaml diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index af0cb8a15..eb825a722 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .result import JSONSchemaOutput, ManualJSONOutput, ToolOutput +from .result import JsonSchemaOutput, ManualJsonOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -43,8 +43,8 @@ 'RunContext', # result 'ToolOutput', - 'JSONSchemaOutput', - 'ManualJSONOutput', + 'JsonSchemaOutput', + 'ManualJsonOutput', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 19ed724e8..7deed39b9 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -161,7 +161,7 @@ class TextOutput(Generic[OutputDataT]): @dataclass(init=False) -class JSONSchemaOutput(Generic[OutputDataT]): +class JsonSchemaOutput(Generic[OutputDataT]): """Marker class to use JSON schema output for outputs.""" output_types: Sequence[OutputTypeOrFunction[OutputDataT]] @@ -183,7 +183,7 @@ def __init__( self.strict = strict -class ManualJSONOutput(Generic[OutputDataT]): +class ManualJsonOutput(Generic[OutputDataT]): """Marker class to use manual JSON mode for outputs.""" output_types: Sequence[OutputTypeOrFunction[OutputDataT]] @@ -214,8 +214,8 @@ def __init__( ToolOutput[T_co], TextOutput[T_co], Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], - JSONSchemaOutput[T_co], - ManualJSONOutput[T_co], + JsonSchemaOutput[T_co], + ManualJsonOutput[T_co], ], type_params=(T_co,), ) @@ -255,7 +255,7 @@ def __init__( self.text_output_schema = OutputTextSchema(output_type) return - if isinstance(output_type, JSONSchemaOutput): + if isinstance(output_type, JsonSchemaOutput): self.mode = 'json_schema' self.text_output_schema = self._build_text_output_schema( output_type.output_types, @@ -265,7 +265,7 @@ def __init__( ) return - if isinstance(output_type, ManualJSONOutput): + if isinstance(output_type, ManualJsonOutput): self.mode = 'manual_json' self.text_output_schema = self._build_text_output_schema( output_type.output_types, name=output_type.name, description=output_type.description diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index de535aba0..b99506c6f 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -11,8 +11,8 @@ from . import _utils, exceptions, messages as _messages, models from ._output import ( - JSONSchemaOutput, - ManualJSONOutput, + JsonSchemaOutput, + ManualJsonOutput, OutputDataT, OutputDataT_inv, OutputSchema, @@ -24,7 +24,7 @@ from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'JSONSchemaOutput', 'ManualJSONOutput', 'OutputValidatorFunc' +__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'JsonSchemaOutput', 'ManualJsonOutput', 'OutputValidatorFunc' T = TypeVar('T') diff --git a/tests/models/cassettes/test_openai/test_openai_text_output_function.yaml b/tests/models/cassettes/test_openai/test_openai_text_output_function.yaml new file mode 100644 index 000000000..9a2f3c06f --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_text_output_function.yaml @@ -0,0 +1,191 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '303' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '432' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_J1YabdC7G7kzEZNbbZopwenH + type: function + created: 1749504053 + id: chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 42 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 53 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '534' + content-type: + - application/json + cookie: + - __cf_bm=YTub3t5GuFdFQZLwCTHT2eGO.fT0zx3Sk2kEY.wvtik-1749504053-1.0.1.1-BMg98yRknUs3LAtnRn_3w1W2X4aoKkKWHIwaBFv.1bdfOF._ZCV0pIGVcI1saCXHR9BMUfQzhTdEPeLlXocUxVzzYQCNTOAxf21UZXcs.ks; + _cfuvid=u8gIns9XYwRGSqmjviw_hUFmKp.LpNqiNvoFMcyyK40-1749504053813-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_J1YabdC7G7kzEZNbbZopwenH + type: function + - content: Mexico + role: tool + tool_call_id: call_J1YabdC7G7kzEZNbbZopwenH + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '844' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '449' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: The largest city in Mexico is Mexico City. + refusal: null + role: assistant + created: 1749504054 + id: chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 10 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 63 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 73 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 9b58a3544..c644e03f8 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -15,7 +15,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior -from pydantic_ai._output import ManualJSONOutput, TextOutput +from pydantic_ai._output import ManualJsonOutput, TextOutput from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -35,7 +35,7 @@ from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import JSONSchemaOutput, ToolOutput, Usage +from pydantic_ai.result import JsonSchemaOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -1847,6 +1847,7 @@ async def get_user_country() -> str: ) +@pytest.mark.vcr() async def test_openai_text_output_function(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) @@ -1874,7 +1875,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ - ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_bZpN3tcL7reJvaSfcJhWIUaj') + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH') ], usage=Usage( requests=1, @@ -1891,14 +1892,14 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgdk2GPEsxXyA9st3DaRFB6bRNXQa', + vendor_id='chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8', ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='call_bZpN3tcL7reJvaSfcJhWIUaj', + tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH', timestamp=IsDatetime(), ) ] @@ -1920,7 +1921,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-Bgdk32OMvBlPIWjnfe4O1S4fVhYiv', + vendor_id='chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH', ), ] ) @@ -1934,7 +1935,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=JSONSchemaOutput(CityLocation)) + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -2020,7 +2021,7 @@ class CountryLanguage(BaseModel): language: str # TODO: Test with functions! - agent = Agent(m, output_type=JSONSchemaOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: @@ -2107,7 +2108,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=ManualJSONOutput(CityLocation)) + agent = Agent(m, output_type=ManualJsonOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -2211,7 +2212,7 @@ class CountryLanguage(BaseModel): language: str # TODO: Test with functions! - agent = Agent(m, output_type=ManualJSONOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=ManualJsonOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: From 81312dc0cf4209ec42fc0979f7800b62886233bf Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 00:42:24 +0000 Subject: [PATCH 08/90] Implement additional output modes in GeminiModel and GoogleModel --- pydantic_ai_slim/pydantic_ai/__init__.py | 4 +- pydantic_ai_slim/pydantic_ai/_output.py | 64 ++- pydantic_ai_slim/pydantic_ai/agent.py | 4 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 35 +- pydantic_ai_slim/pydantic_ai/models/google.py | 39 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 62 ++- .../pydantic_ai/profiles/__init__.py | 2 +- .../pydantic_ai/profiles/google.py | 2 +- .../pydantic_ai/profiles/openai.py | 10 +- pydantic_ai_slim/pydantic_ai/result.py | 13 +- .../test_google_json_schema_output.yaml | 86 ++++ ...st_google_json_schema_output_multiple.yaml | 138 +++++ .../test_google_prompted_json_output.yaml | 78 +++ ..._google_prompted_json_output_multiple.yaml | 77 +++ ...oogle_prompted_json_output_with_tools.yaml | 164 ++++++ .../test_google_text_output_function.yaml | 147 ++++++ .../test_google/test_google_tool_output.yaml | 187 +++++++ ...st_openai_json_schema_output_multiple.yaml | 22 +- ... => test_openai_prompted_json_output.yaml} | 68 ++- ...openai_prompted_json_output_multiple.yaml} | 50 +- .../test_json_schema_output.yaml | 288 +++++++++++ .../test_json_schema_output_multiple.yaml | 444 ++++++++++++++++ .../test_prompted_json_output.yaml | 70 +++ .../test_prompted_json_output_multiple.yaml | 70 +++ .../test_text_output_function.yaml | 228 +++++++++ .../test_tool_output.yaml | 282 ++++++++++ tests/models/test_google.py | 482 +++++++++++++++++- tests/models/test_openai.py | 77 ++- tests/models/test_openai_responses.py | 334 ++++++++++++ 29 files changed, 3348 insertions(+), 179 deletions(-) create mode 100644 tests/models/cassettes/test_google/test_google_json_schema_output.yaml create mode 100644 tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml create mode 100644 tests/models/cassettes/test_google/test_google_prompted_json_output.yaml create mode 100644 tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml create mode 100644 tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml create mode 100644 tests/models/cassettes/test_google/test_google_text_output_function.yaml create mode 100644 tests/models/cassettes/test_google/test_google_tool_output.yaml rename tests/models/cassettes/test_openai/{test_openai_manual_json_output.yaml => test_openai_prompted_json_output.yaml} (75%) rename tests/models/cassettes/test_openai/{test_openai_manual_json_output_multiple.yaml => test_openai_prompted_json_output_multiple.yaml} (84%) create mode 100644 tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_text_output_function.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_tool_output.yaml diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index eb825a722..43d985dc4 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .result import JsonSchemaOutput, ManualJsonOutput, ToolOutput +from .result import JsonSchemaOutput, PromptedJsonOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -44,7 +44,7 @@ # result 'ToolOutput', 'JsonSchemaOutput', - 'ManualJsonOutput', + 'PromptedJsonOutput', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 7deed39b9..95e89a85a 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -2,6 +2,7 @@ import inspect import json +import re from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field from textwrap import dedent @@ -54,11 +55,9 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result' DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' -DEFAULT_MANUAL_JSON_PROMPT = dedent( # TODO: Move to ModelProfile +DEFAULT_PROMPTED_JSON_PROMPT = dedent( """ - Always respond with a JSON object matching this description and schema: - - {description} + Always respond with a JSON object that's compatible with this schema: {schema} @@ -175,7 +174,7 @@ def __init__( *, name: str | None = None, description: str | None = None, - strict: bool | None = None, + strict: bool | None = True, ): self.output_types = flatten_output_types(type_) self.name = name @@ -183,7 +182,7 @@ def __init__( self.strict = strict -class ManualJsonOutput(Generic[OutputDataT]): +class PromptedJsonOutput(Generic[OutputDataT]): """Marker class to use manual JSON mode for outputs.""" output_types: Sequence[OutputTypeOrFunction[OutputDataT]] @@ -215,13 +214,13 @@ def __init__( TextOutput[T_co], Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], JsonSchemaOutput[T_co], - ManualJsonOutput[T_co], + PromptedJsonOutput[T_co], ], type_params=(T_co,), ) # TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose automatically, relying on Pydantic validation -OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'manual_json'] +OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'prompted_json'] @dataclass(init=False) @@ -265,8 +264,8 @@ def __init__( ) return - if isinstance(output_type, ManualJsonOutput): - self.mode = 'manual_json' + if isinstance(output_type, PromptedJsonOutput): + self.mode = 'prompted_json' self.text_output_schema = self._build_text_output_schema( output_type.output_types, name=output_type.name, description=output_type.description ) @@ -368,7 +367,7 @@ def _build_text_output_schema( if len(outputs) == 1: return OutputObjectSchema(output_type=outputs[0], name=name, description=description, strict=strict) - return OutputUnionSchema(output_types=outputs, name=name, description=description, strict=strict) + return OutputUnionSchema(output_types=outputs, strict=strict) @property def allow_text_output(self) -> Literal['plain', 'json', False]: @@ -382,7 +381,7 @@ def allow_text_output(self) -> Literal['plain', 'json', False]: def is_mode_supported(self, profile: ModelProfile) -> bool: """Whether the model supports the output mode.""" mode = self.mode - if mode in ('text', 'manual_json'): + if mode in ('text', 'prompted_json'): return True if self.mode == 'tool_or_text': mode = 'tool' @@ -413,6 +412,8 @@ def tool_names(self) -> list[str]: def tool_defs(self) -> list[ToolDefinition]: """Get tool definitions to register with the model.""" + if self.mode not in ('tool', 'tool_or_text'): + return [] return [t.tool_def for t in self.tools.values()] async def process( @@ -436,7 +437,19 @@ async def process( assert self.allow_text_output is not False assert self.text_output_schema is not None - # TODO: Strip Markdown fences? + def strip_markdown_fences(text: str) -> str: + if text.startswith('{'): + return text + + regex = r'```(?:\w+)?\n(\{.*\})\n```' + match = re.search(regex, text, re.DOTALL) + if match: + return match.group(1) + + return text + + text = strip_markdown_fences(text) + return await self.text_output_schema.process( text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -444,17 +457,22 @@ async def process( @dataclass class OutputObjectDefinition: - name: str json_schema: ObjectJsonSchema + name: str | None = None description: str | None = None strict: bool | None = None @property - def manual_json_instructions(self) -> str: + def instructions(self) -> str: """Get instructions for model to output manual JSON matching the schema.""" - # TODO: Move to ModelProfile so it can be tweaked - description = ': '.join([v for v in [self.name, self.description] if v]) - return DEFAULT_MANUAL_JSON_PROMPT.format(schema=json.dumps(self.json_schema), description=description) + schema = self.json_schema.copy() + if self.name and not schema.get('title'): + schema['title'] = self.name + if self.description and not schema.get('description'): + schema['description'] = self.description + + # Eventually move DEFAULT_PROMPTED_JSON_PROMPT to ModelProfile so it can be tweaked on a per model basis + return DEFAULT_PROMPTED_JSON_PROMPT.format(schema=json.dumps(schema)) @dataclass(init=False) @@ -478,14 +496,13 @@ class OutputUnionSchema(Generic[OutputDataT]): def __init__( self, output_types: Sequence[OutputTypeOrFunction[OutputDataT]], - name: str | None = None, - description: str | None = None, strict: bool | None = None, ): self._object_schemas = {} # TODO: Ensure keys are unique self._object_schemas = { - output_type.__name__: OutputObjectSchema(output_type=output_type) for output_type in output_types + output_type.__name__: OutputObjectSchema(output_type=output_type, strict=strict) + for output_type in output_types } self._root_object_schema = OutputObjectSchema(output_type=OutputUnionData) @@ -500,6 +517,7 @@ def __init__( 'type': 'object', 'properties': { 'kind': { + 'type': 'string', 'const': name, }, 'data': object_schema.object_def.json_schema, # TODO: Pop description here? @@ -517,8 +535,6 @@ def __init__( } self.object_def = OutputObjectDefinition( - name=name or DEFAULT_OUTPUT_TOOL_NAME, - description=description or DEFAULT_OUTPUT_TOOL_DESCRIPTION, json_schema=json_schema, strict=strict, ) @@ -597,7 +613,7 @@ def __init__( description = f'{description}. {json_schema_description}' self.object_def = OutputObjectDefinition( - name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME), + name=name or getattr(output_type, '__name__'), description=description, json_schema=json_schema, strict=strict, diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 31cef815e..ac62a43a2 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -676,11 +676,11 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: ] if ( - output_schema.mode == 'manual_json' + output_schema.mode == 'prompted_json' and (output_object_schema := output_schema.text_output_schema) and (object_def := output_object_schema.object_def) ): - parts.append(object_def.manual_json_instructions) + parts.append(object_def.instructions) parts = [p for p in parts if p] if not parts: diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c85254688..1cfc5a364 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -16,6 +16,8 @@ from pydantic_ai.providers import Provider, infer_provider from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._output import OutputObjectDefinition +from ..exceptions import UserError from ..messages import ( AudioUrl, BinaryContent, @@ -192,12 +194,12 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: - if model_request_parameters.output_mode != 'tool': - return None - elif tools: + if not tools: + return _tool_config([]) # pragma: no cover + elif model_request_parameters.output_mode == 'tool': return _tool_config([t['name'] for t in tools['function_declarations']]) else: - return _tool_config([]) # pragma: no cover + return None @asynccontextmanager async def _make_request( @@ -219,6 +221,19 @@ async def _make_request( if tool_config is not None: request_data['toolConfig'] = tool_config + output_mode = model_request_parameters.output_mode + if output_mode == 'json_schema': + request_data['responseMimeType'] = 'application/json' + + output_object = model_request_parameters.output_object + assert output_object is not None + request_data['responseSchema'] = self._map_response_schema(output_object) + + if tools: + raise UserError('Google does not support JSON schema output and tools at the same time.') + elif output_mode == 'prompted_json' and not tools: + request_data['responseMimeType'] = 'application/json' + generation_config = _settings_to_generation_config(model_settings) if generation_config: request_data['generationConfig'] = generation_config @@ -361,6 +376,15 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion] assert_never(item) return content + def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: + response_schema = o.json_schema.copy() + if o.name: + response_schema['title'] = o.name + if o.description: + response_schema['description'] = o.description + + return response_schema + def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -504,6 +528,9 @@ class _GeminiRequest(TypedDict): generationConfig: NotRequired[_GeminiGenerationConfig] labels: NotRequired[dict[str, str]] + responseMimeType: NotRequired[str] + responseSchema: NotRequired[dict[str, Any]] + class GeminiSafetySettings(TypedDict): """Safety settings options for Gemini model request. diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index ac837cf46..57612567f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -10,9 +10,10 @@ from typing_extensions import assert_never -from pydantic_ai.providers import Provider +from pydantic_ai._output import OutputObjectDefinition from .. import UnexpectedModelBehavior, _utils, usage +from ..exceptions import UserError from ..messages import ( AudioUrl, BinaryContent, @@ -32,6 +33,7 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..providers import Provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -213,9 +215,9 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: - if model_request_parameters.output_mode != 'tool': - return None - elif tools: + if not tools: + return _tool_config([]) # pragma: no cover + elif model_request_parameters.output_mode == 'tool': names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: @@ -223,7 +225,7 @@ def _get_tool_config( names.append(name) return _tool_config(names) else: - return _tool_config([]) # pragma: no cover + return None @overload async def _generate_content( @@ -251,6 +253,22 @@ async def _generate_content( model_request_parameters: ModelRequestParameters, ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]: tools = self._get_tools(model_request_parameters) + + output_mode = model_request_parameters.output_mode + response_mime_type = None + response_schema = None + if output_mode == 'json_schema': + response_mime_type = 'application/json' + + output_object = model_request_parameters.output_object + assert output_object is not None + response_schema = self._map_response_schema(output_object) + + if tools: + raise UserError('Google does not support JSON schema output and tools at the same time/') + elif output_mode == 'prompted_json' and not tools: + response_mime_type = 'application/json' + tool_config = self._get_tool_config(model_request_parameters, tools) system_instruction, contents = await self._map_messages(messages) @@ -268,6 +286,8 @@ async def _generate_content( labels=model_settings.get('google_labels'), tools=cast(ToolListUnionDict, tools), tool_config=tool_config, + response_mime_type=response_mime_type, + response_schema=response_schema, ) func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content @@ -383,6 +403,15 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]: assert_never(item) return content + def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: + response_schema = o.json_schema.copy() + if o.name: + response_schema['title'] = o.name + if o.description: + response_schema['description'] = o.description + + return response_schema + @dataclass class GeminiStreamedResponse(StreamedResponse): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 4c9069010..c4c903434 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -14,7 +14,7 @@ from pydantic_ai.providers import Provider, infer_provider from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage -from .._output import OutputObjectDefinition +from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( AudioUrl, @@ -263,25 +263,29 @@ async def _completions_create( model_settings: OpenAIModelSettings, model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: + tools = self._get_tools(model_request_parameters) + if not tools: + tool_choice: Literal['none', 'required', 'auto'] | None = None + elif model_request_parameters.output_mode == 'tool': + tool_choice = 'required' + else: + tool_choice = 'auto' + openai_messages = await self._map_messages(messages) tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] response_format: chat.completion_create_params.ResponseFormat | None = None output_mode = model_request_parameters.output_mode - if output_mode == 'tool': - tools.extend(self._map_tool_definition(r) for r in model_request_parameters.output_tools) - elif output_mode == 'json_schema': + if output_mode == 'json_schema': output_object = model_request_parameters.output_object assert output_object is not None - response_format = self._map_output_object_definition(output_object) - - if not tools: - tool_choice: Literal['none', 'required', 'auto'] | None = None - elif model_request_parameters.output_mode == 'tool': - tool_choice = 'required' - else: - tool_choice = 'auto' + response_format = self._map_json_schema(output_object) + elif ( + output_mode == 'prompted_json' + and OpenAIModelProfile.from_profile(self.profile).openai_supports_json_object_response_format + ): + response_format = {'type': 'json_object'} try: extra_headers = model_settings.get('extra_headers', {}) @@ -417,12 +421,11 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: ) @staticmethod - def _map_output_object_definition(o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: - # TODO: Use ResponseFormatJSONObject on older models + def _map_json_schema(o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage] 'type': 'json_schema', 'json_schema': { - 'name': o.name, + 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, }, } @@ -671,7 +674,6 @@ async def _responses_create( tools = self._get_tools(model_request_parameters) tools = list(model_settings.get('openai_builtin_tools', [])) + tools - # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None elif model_request_parameters.output_mode == 'tool': @@ -682,11 +684,21 @@ async def _responses_create( instructions, openai_messages = await self._map_messages(messages) reasoning = self._get_reasoning(model_settings) + text: responses.ResponseTextConfigParam | None = None + output_mode = model_request_parameters.output_mode + if output_mode == 'json_schema': + output_object = model_request_parameters.output_object + assert output_object is not None + text = {'format': self._map_json_schema(output_object)} + elif ( + output_mode == 'prompted_json' + and OpenAIModelProfile.from_profile(self.profile).openai_supports_json_object_response_format + ): + text = {'format': {'type': 'json_object'}} + try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) - # TODO: Pass text.format = ResponseFormatTextJSONSchemaConfigParam(...): {'type': 'json_schema', 'strict': True, 'name': '...', 'schema': ...} - # TODO: Fall back on ResponseFormatJSONObject/json_object on older models? return await self.client.responses.create( input=openai_messages, model=self._model_name, @@ -702,6 +714,7 @@ async def _responses_create( timeout=model_settings.get('timeout', NOT_GIVEN), reasoning=reasoning, user=model_settings.get('openai_user', NOT_GIVEN), + text=text or NOT_GIVEN, extra_headers=extra_headers, extra_body=model_settings.get('extra_body'), ) @@ -793,6 +806,19 @@ def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam: type='function_call', ) + @staticmethod + def _map_json_schema(o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam: + response_format_param: responses.ResponseFormatTextJSONSchemaConfigParam = { + 'type': 'json_schema', + 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, + 'schema': o.json_schema, + } + if o.description: + response_format_param['description'] = o.description + if o.strict: + response_format_param['strict'] = o.strict + return response_format_param + @staticmethod async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam: content: str | list[responses.ResponseInputContentParam] diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 743442c66..9c311d685 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -15,7 +15,7 @@ class ModelProfile: json_schema_transformer: type[JsonSchemaTransformer] | None = None output_modes: set[Literal['tool', 'json_schema']] = field(default_factory=lambda: {'tool'}) # TODO: Add docstrings - default_output_mode: Literal['tool', 'json_schema', 'manual_json'] = 'tool' + default_output_mode: Literal['tool', 'json_schema', 'prompted_json'] = 'tool' @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index c544e185b..a0cdc61fc 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -10,7 +10,7 @@ def google_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Google model.""" - return ModelProfile(json_schema_transformer=GoogleJsonSchemaTransformer) + return ModelProfile(json_schema_transformer=GoogleJsonSchemaTransformer, output_modes={'tool', 'json_schema'}) class GoogleJsonSchemaTransformer(JsonSchemaTransformer): diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index a708000fe..a2256043f 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -15,12 +15,20 @@ class OpenAIModelProfile(ModelProfile): ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. """ - # This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions openai_supports_strict_tool_definition: bool = True + """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support strict tool definitions.""" + + openai_supports_json_object_response_format: bool = True + """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support the `json_object` `response_format`. + Note that if a model does not support the `json_schema` `response_format`, that value should be removed from `ModelProfile.output_modes`. + """ def openai_model_profile(model_name: str) -> ModelProfile: """Get the model profile for an OpenAI model.""" + # json_schema is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later, + # but we leave it in here for all models because the `default_output_mode` is `'tool'`, so `json_schema` is only used + # when the user specifically uses the JsonSchemaOutput marker, so an error from the API is acceptable. return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer, output_modes={'tool', 'json_schema'}) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index b99506c6f..ef2c82aa4 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -12,19 +12,28 @@ from . import _utils, exceptions, messages as _messages, models from ._output import ( JsonSchemaOutput, - ManualJsonOutput, OutputDataT, OutputDataT_inv, OutputSchema, OutputValidator, OutputValidatorFunc, + PromptedJsonOutput, + TextOutput, ToolOutput, ) from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'JsonSchemaOutput', 'ManualJsonOutput', 'OutputValidatorFunc' +__all__ = ( + 'OutputDataT', + 'OutputDataT_inv', + 'ToolOutput', + 'TextOutput', + 'JsonSchemaOutput', + 'PromptedJsonOutput', + 'OutputValidatorFunc', +) T = TypeVar('T') diff --git a/tests/models/cassettes/test_google/test_google_json_schema_output.yaml b/tests/models/cassettes/test_google/test_google_json_schema_output.yaml new file mode 100644 index 000000000..1d9ae0339 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_json_schema_output.yaml @@ -0,0 +1,86 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '453' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + responseSchema: + properties: + city: + type: STRING + country: + type: STRING + property_ordering: + - city + - country + required: + - city + - country + title: CityLocation + type: OBJECT + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '710' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=780 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0002309985226020217 + content: + parts: + - text: |- + { + "city": "Mexico City", + "country": "Mexico" + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: Gm9HaNr3KteI_NUPmYvnoA8 + usageMetadata: + candidatesTokenCount: 20 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 20 + promptTokenCount: 19 + promptTokensDetails: + - modality: TEXT + tokenCount: 19 + totalTokenCount: 39 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml new file mode 100644 index 000000000..74dd03c89 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml @@ -0,0 +1,138 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1200' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the primarily language spoken in Mexico? + role: user + generationConfig: + responseMimeType: application/json + responseSchema: + description: The final response which ends this conversation + properties: + result: + any_of: + - description: CityLocation + properties: + data: + properties: + city: + type: STRING + country: + type: STRING + property_ordering: + - city + - country + required: + - city + - country + type: OBJECT + kind: + enum: + - CityLocation + type: STRING + property_ordering: + - kind + - data + required: + - kind + - data + type: OBJECT + - description: CountryLanguage + properties: + data: + properties: + country: + type: STRING + language: + type: STRING + property_ordering: + - country + - language + required: + - country + - language + type: OBJECT + kind: + enum: + - CountryLanguage + type: STRING + property_ordering: + - kind + - data + required: + - kind + - data + type: OBJECT + required: + - result + title: final_result + type: OBJECT + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '800' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=884 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0002536005138055138 + content: + parts: + - text: |- + { + "result": { + "kind": "CountryLanguage", + "data": { + "country": "Mexico", + "language": "Spanish" + } + } + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: W29HaJzGMNGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 46 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 46 + promptTokenCount: 64 + promptTokensDetails: + - modality: TEXT + tokenCount: 64 + totalTokenCount: 110 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml new file mode 100644 index 000000000..3b241acae --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml @@ -0,0 +1,78 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '619' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '879' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=829 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.010130892906870161 + content: + parts: + - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], + "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 4HlHaK75MdGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 56 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 56 + promptTokenCount: 80 + promptTokensDetails: + - modality: TEXT + tokenCount: 80 + totalTokenCount: 136 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..33383473f --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml @@ -0,0 +1,77 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1341' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: user + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '758' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=734 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0008548707873732956 + content: + parts: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 6nlHaO_5GdeI_NUPmYvnoA8 + usageMetadata: + candidatesTokenCount: 27 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 27 + promptTokenCount: 241 + promptTokensDetails: + - modality: TEXT + tokenCount: 241 + totalTokenCount: 268 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml new file mode 100644 index 000000000..976533c66 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml @@ -0,0 +1,164 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '658' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + generationConfig: {} + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '653' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=3776 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: FnpHaOqcKrzQz7IPkuLo8QE + usageMetadata: + candidatesTokenCount: 12 + promptTokenCount: 123 + promptTokensDetails: + - modality: TEXT + tokenCount: 123 + thoughtsTokenCount: 266 + totalTokenCount: 401 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '967' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - functionCall: + args: {} + id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 + name: get_user_country + role: model + - parts: + - functionResponse: + id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 + name: get_user_country + response: + return_value: Mexico + role: user + generationConfig: {} + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '630' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1888 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: |- + ```json + {"city": "Mexico City", "country": "Mexico"} + ``` + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: GHpHaOPkI43Qz7IPxt6T2Ac + usageMetadata: + candidatesTokenCount: 18 + promptTokenCount: 154 + promptTokensDetails: + - modality: TEXT + tokenCount: 154 + thoughtsTokenCount: 94 + totalTokenCount: 266 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_text_output_function.yaml b/tests/models/cassettes/test_google/test_google_text_output_function.yaml new file mode 100644 index 000000000..ebfbfc86f --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_text_output_function.yaml @@ -0,0 +1,147 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '279' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + generationConfig: {} + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '769' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=2956 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: | + Okay, I can help with that. First, I need to determine your country. + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: J25HaLv8GvDQz7IPp_zUiQo + usageMetadata: + candidatesTokenCount: 30 + promptTokenCount: 49 + promptTokensDetails: + - modality: TEXT + tokenCount: 49 + thoughtsTokenCount: 159 + totalTokenCount: 238 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '672' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - text: | + Okay, I can help with that. First, I need to determine your country. + - functionCall: + args: {} + id: pyd_ai_82dd46d016b24cf999ce5d812b383f1a + name: get_user_country + role: model + - parts: + - functionResponse: + id: pyd_ai_82dd46d016b24cf999ce5d812b383f1a + name: get_user_country + response: + return_value: Mexico + role: user + generationConfig: {} + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '637' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1426 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: Based on the information I have, the largest city in Mexico is Mexico City. + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: KG5HaKT3Nc2fz7IPy9KsuQU + usageMetadata: + candidatesTokenCount: 16 + promptTokenCount: 98 + promptTokensDetails: + - modality: TEXT + tokenCount: 98 + thoughtsTokenCount: 45 + totalTokenCount: 159 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_tool_output.yaml b/tests/models/cassettes/test_google/test_google_tool_output.yaml new file mode 100644 index 000000000..bebefdf6a --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_tool_output.yaml @@ -0,0 +1,187 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '568' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + generationConfig: {} + toolConfig: + functionCallingConfig: + allowedFunctionNames: + - get_user_country + - final_result + mode: ANY + tools: + - functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: STRING + country: + type: STRING + required: + - city + - country + type: OBJECT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '733' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=644 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: 5.670217797160149e-06 + content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: F21HaLmGI5m2nvgP-__7yAg + usageMetadata: + candidatesTokenCount: 5 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 5 + promptTokenCount: 32 + promptTokensDetails: + - modality: TEXT + tokenCount: 32 + totalTokenCount: 37 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '877' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + - parts: + - functionCall: + args: {} + id: pyd_ai_9bbd9b896939438e8ff5aba64fed8674 + name: get_user_country + role: model + - parts: + - functionResponse: + id: pyd_ai_9bbd9b896939438e8ff5aba64fed8674 + name: get_user_country + response: + return_value: Mexico + role: user + generationConfig: {} + toolConfig: + functionCallingConfig: + allowedFunctionNames: + - get_user_country + - final_result + mode: ANY + tools: + - functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: STRING + country: + type: STRING + required: + - city + - country + type: OBJECT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '821' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=531 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -3.289666346972808e-05 + content: + parts: + - functionCall: + args: + city: Mexico City + country: Mexico + name: final_result + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: GG1HaMXtBoW8nvgPkaDy0Ag + usageMetadata: + candidatesTokenCount: 8 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 8 + promptTokenCount: 46 + promptTokensDetails: + - modality: TEXT + tokenCount: 46 + totalTokenCount: 54 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml index bda52f9bf..d01e28ab0 100644 --- a/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml +++ b/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml @@ -97,7 +97,7 @@ interactions: openai-organization: - pydantic-28gund openai-processing-ms: - - '999' + - '868' openai-version: - '2020-10-01' strict-transport-security: @@ -118,10 +118,10 @@ interactions: - function: arguments: '{}' name: get_user_country - id: call_NiLmkD3Yi30ax2IY7t14e3AP + id: call_SIttSeiOistt33Htj4oiHOOX type: function - created: 1748919916 - id: chatcmpl-BeCFgrwfENi1OwavvP8itSOMTKjwY + created: 1749511286 + id: chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD model: gpt-4o-2024-08-06 object: chat.completion service_tier: default @@ -154,8 +154,8 @@ interactions: content-type: - application/json cookie: - - __cf_bm=4QxpNnP8_u4FyljGgYUUF5NYWCyqa2OJvgxKjyMEh4Y-1748919917-1.0.1.1-LsEozdwJs4K6NOAxtY3kw9dzZ6JHe.l4h4qkIENfShBXiUE6C5V9ED_hCbYeM.GMdC13g7SAlw1iuh5HCTMtOzNvTr_j_jvPbLY3p35HCbM; - _cfuvid=.3C6J8WR_NWUd_EBaQOj9bFncgO.R9A8576Zi3GczTg-1748919917308-0.0.1.1-604800000 + - __cf_bm=OFzdr.HrmtC0DNdnfrTQYsK8_PwAVR9GUqjYSCgwtVM-1749511286-1.0.1.1-9_dbth7ET4rzl01UDRTw3fY1nJ20FnMCC0BBmd57gzKF8n5DnNQaI4K1mT.23nn9IUsMyHAZUNn6t1EML3d7GfGJyiLZOxrTWaqacALgzlM; + _cfuvid=f32dQYPsRd6Jc7kg.3hHa1QYAyG8f_aMMXUF.bC6gmY-1749511286914-0.0.1.1-604800000 host: - api.openai.com method: POST @@ -168,11 +168,11 @@ interactions: - function: arguments: '{}' name: get_user_country - id: call_NiLmkD3Yi30ax2IY7t14e3AP + id: call_SIttSeiOistt33Htj4oiHOOX type: function - content: Mexico role: tool - tool_call_id: call_NiLmkD3Yi30ax2IY7t14e3AP + tool_call_id: call_SIttSeiOistt33Htj4oiHOOX model: gpt-4o response_format: json_schema: @@ -252,7 +252,7 @@ interactions: openai-organization: - pydantic-28gund openai-processing-ms: - - '867' + - '920' openai-version: - '2020-10-01' strict-transport-security: @@ -269,8 +269,8 @@ interactions: content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' refusal: null role: assistant - created: 1748919918 - id: chatcmpl-BeCFiQtbjmFUzYbmXlkAEWbc0peoL + created: 1749511287 + id: chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd model: gpt-4o-2024-08-06 object: chat.completion service_tier: default diff --git a/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml similarity index 75% rename from tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml rename to tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml index 56023f426..4eed79085 100644 --- a/tests/models/cassettes/test_openai/test_openai_manual_json_output.yaml +++ b/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '627' + - '690' content-type: - application/json host: @@ -16,21 +16,20 @@ interactions: method: POST parsed_body: messages: - - content: |2 + - content: |- + Always respond with a JSON object that's compatible with this schema: - Always respond with a JSON object matching this description and schema: - - CityLocation - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"} + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} Don't include any text or Markdown fencing before or after. role: system - content: What is the largest city in the user country? role: user model: gpt-4o - n: 1 + response_format: + type: json_object stream: false + tool_choice: auto tools: - function: description: '' @@ -56,7 +55,7 @@ interactions: openai-organization: - pydantic-28gund openai-processing-ms: - - '430' + - '569' openai-version: - '2020-10-01' strict-transport-security: @@ -77,26 +76,26 @@ interactions: - function: arguments: '{}' name: get_user_country - id: call_uTjt2vMkeTr0GYqQyQYrUUhl + id: call_s7oT9jaLAsEqTgvxZTmFh0wB type: function - created: 1747154400 - id: chatcmpl-BWmxcJf2wXM37xTze50IfAoyuaoKb + created: 1749514895 + id: chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw model: gpt-4o-2024-08-06 object: chat.completion service_tier: default - system_fingerprint: fp_55d88aaf2f + system_fingerprint: fp_07871e2ad8 usage: - completion_tokens: 12 + completion_tokens: 11 completion_tokens_details: accepted_prediction_tokens: 0 audio_tokens: 0 reasoning_tokens: 0 rejected_prediction_tokens: 0 - prompt_tokens: 106 + prompt_tokens: 109 prompt_tokens_details: audio_tokens: 0 cached_tokens: 0 - total_tokens: 118 + total_tokens: 120 status: code: 200 message: OK @@ -109,24 +108,21 @@ interactions: connection: - keep-alive content-length: - - '858' + - '921' content-type: - application/json cookie: - - __cf_bm=95NT6qevASASUyV3RVHQoxZGp8lnU1dQzcdShJ0rQ8o-1747154400-1.0.1.1-zowTt2i3mTZlYQ8gezUuRRLY_0dw6L6iD5qfaNySs0KmHmLd2JFwYun1kZJ61S03BecMhUdxy.FiOWLq2LdY.RuTR7vePLyoCrMmCDa4vpk; - _cfuvid=hgD2spnngVs.0HuyvQx7_W1uCro2gMmGvsKkZTUk3H0-1747154400314-0.0.1.1-604800000 + - __cf_bm=jcec.FXQ2vs1UTNFhcDbuMrvzdFu7d7L1To24_vRFiQ-1749514896-1.0.1.1-PEeul2ZYkvLFmEXXk4Xlgvun2HcuGEJ0UUliLVWKx17kMCjZ8WiZbB2Yavq3RRGlxsJZsAWIVMQQ10Vb_2aqGVtQ2aiYTlnDMX3Ktkuciyk; + _cfuvid=zanrNpp5OAiS0wLKfkW9LCs3qTO2FvIaiBZptR_D2P0-1749514896187-0.0.1.1-604800000 host: - api.openai.com method: POST parsed_body: messages: - - content: |2 - - Always respond with a JSON object matching this description and schema: - - CityLocation + - content: |- + Always respond with a JSON object that's compatible with this schema: - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"} + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} Don't include any text or Markdown fencing before or after. role: system @@ -137,14 +133,16 @@ interactions: - function: arguments: '{}' name: get_user_country - id: call_uTjt2vMkeTr0GYqQyQYrUUhl + id: call_s7oT9jaLAsEqTgvxZTmFh0wB type: function - content: Mexico role: tool - tool_call_id: call_uTjt2vMkeTr0GYqQyQYrUUhl + tool_call_id: call_s7oT9jaLAsEqTgvxZTmFh0wB model: gpt-4o - n: 1 + response_format: + type: json_object stream: false + tool_choice: auto tools: - function: description: '' @@ -170,7 +168,7 @@ interactions: openai-organization: - pydantic-28gund openai-processing-ms: - - '2453' + - '718' openai-version: - '2020-10-01' strict-transport-security: @@ -187,24 +185,24 @@ interactions: content: '{"city":"Mexico City","country":"Mexico"}' refusal: null role: assistant - created: 1747154401 - id: chatcmpl-BWmxdIs5pO5RCbQ9qRxxtWVItB4NU + created: 1749514896 + id: chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0 model: gpt-4o-2024-08-06 object: chat.completion service_tier: default - system_fingerprint: fp_d8864f8b6b + system_fingerprint: fp_07871e2ad8 usage: - completion_tokens: 12 + completion_tokens: 11 completion_tokens_details: accepted_prediction_tokens: 0 audio_tokens: 0 reasoning_tokens: 0 rejected_prediction_tokens: 0 - prompt_tokens: 127 + prompt_tokens: 130 prompt_tokens_details: audio_tokens: 0 cached_tokens: 0 - total_tokens: 139 + total_tokens: 141 status: code: 200 message: OK diff --git a/tests/models/cassettes/test_openai/test_openai_manual_json_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml similarity index 84% rename from tests/models/cassettes/test_openai/test_openai_manual_json_output_multiple.yaml rename to tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml index 1c143bc6c..3d3ba886a 100644 --- a/tests/models/cassettes/test_openai/test_openai_manual_json_output_multiple.yaml +++ b/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '1438' + - '1412' content-type: - application/json host: @@ -17,9 +17,7 @@ interactions: parsed_body: messages: - content: |- - Always respond with a JSON object matching this description and schema: - - final_result: The final response which ends this conversation + Always respond with a JSON object that's compatible with this schema: {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} @@ -28,6 +26,8 @@ interactions: - content: What is the largest city in the user country? role: user model: gpt-4o + response_format: + type: json_object stream: false tool_choice: auto tools: @@ -55,7 +55,7 @@ interactions: openai-organization: - pydantic-28gund openai-processing-ms: - - '549' + - '428' openai-version: - '2020-10-01' strict-transport-security: @@ -76,14 +76,14 @@ interactions: - function: arguments: '{}' name: get_user_country - id: call_W3kwVF2ZX9cZ2L9NnbSDSs3V + id: call_wJD14IyJ4KKVtjCrGyNCHO09 type: function - created: 1748928408 - id: chatcmpl-BeESepFYXE1ELAEvKlRNvsRzzYkOg + created: 1749514898 + id: chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR model: gpt-4o-2024-08-06 object: chat.completion service_tier: default - system_fingerprint: fp_a288987b44 + system_fingerprint: fp_9bddfca6e2 usage: completion_tokens: 11 completion_tokens_details: @@ -91,11 +91,11 @@ interactions: audio_tokens: 0 reasoning_tokens: 0 rejected_prediction_tokens: 0 - prompt_tokens: 284 + prompt_tokens: 273 prompt_tokens_details: audio_tokens: 0 cached_tokens: 0 - total_tokens: 295 + total_tokens: 284 status: code: 200 message: OK @@ -108,21 +108,19 @@ interactions: connection: - keep-alive content-length: - - '1669' + - '1643' content-type: - application/json cookie: - - __cf_bm=bTNTDMnAXJaO9CpT6TAFSQOHtSGvSowNKQi20qZpxKU-1748928408-1.0.1.1-RtnC2hZt1TL38SpAb3dXL5bL7Q7EuNYgc0.18VudHVT7WzCWgkYNSYscp6aLzd9yCQaDu__K1Q2tP05IZqU3y2KwNh0JTcuf.o8GEEG_kbY; - _cfuvid=6ycbGPnBJWGLcM9mdMSLys2eUHapvWyYqbWltcaQhkQ-1748928408894-0.0.1.1-604800000 + - __cf_bm=gqjIEMZSez95CPkkPVuU_AoDutHrobFMbFPjq43G66M-1749514899-1.0.1.1-5TGB9WajW5pzCRtVtWeQfiwyQUZs1JwWy9qC8VGlgq7s5pQWKerukQtYB7GqNDrdb.1pbtFyt2HZ9xV3YiSbK4H1bZS_hS1CCeoup_3IQW0; + _cfuvid=ZN6eoNau4b.bJ8kvRn2z9R0HgTUd9nOsupKUtLXQowU-1749514899280-0.0.1.1-604800000 host: - api.openai.com method: POST parsed_body: messages: - content: |- - Always respond with a JSON object matching this description and schema: - - final_result: The final response which ends this conversation + Always respond with a JSON object that's compatible with this schema: {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} @@ -135,12 +133,14 @@ interactions: - function: arguments: '{}' name: get_user_country - id: call_W3kwVF2ZX9cZ2L9NnbSDSs3V + id: call_wJD14IyJ4KKVtjCrGyNCHO09 type: function - content: Mexico role: tool - tool_call_id: call_W3kwVF2ZX9cZ2L9NnbSDSs3V + tool_call_id: call_wJD14IyJ4KKVtjCrGyNCHO09 model: gpt-4o + response_format: + type: json_object stream: false tool_choice: auto tools: @@ -168,7 +168,7 @@ interactions: openai-organization: - pydantic-28gund openai-processing-ms: - - '900' + - '763' openai-version: - '2020-10-01' strict-transport-security: @@ -185,12 +185,12 @@ interactions: content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' refusal: null role: assistant - created: 1748928409 - id: chatcmpl-BeESfsdYDCQwP7kGM4r5i8KXwllkT + created: 1749514899 + id: chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC model: gpt-4o-2024-08-06 object: chat.completion service_tier: default - system_fingerprint: fp_a288987b44 + system_fingerprint: fp_9bddfca6e2 usage: completion_tokens: 21 completion_tokens_details: @@ -198,11 +198,11 @@ interactions: audio_tokens: 0 reasoning_tokens: 0 rejected_prediction_tokens: 0 - prompt_tokens: 305 + prompt_tokens: 294 prompt_tokens_details: audio_tokens: 0 cached_tokens: 0 - total_tokens: 326 + total_tokens: 315 status: code: 200 message: OK diff --git a/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml b/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml new file mode 100644 index 000000000..9fd1b6989 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml @@ -0,0 +1,288 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '533' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1808' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '636' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516047 + error: null + id: resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_tTAThu8l2S9hNky2krdwijGP + id: fc_68477f0fa7c081a19a525f7c6f180f310b8591d9001d2329 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 66 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 78 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '769' + content-type: + - application/json + cookie: + - __cf_bm=My3TWVEPFsaYcjJ.iWxTB6P67jFSuxSF.n13qHpH9BA-1749516047-1.0.1.1-2bg2ltV1yu2uhfqewI9eEG1ulzfU_gq8pLx9YwHte33BTk2PgxBwaRdyegdEs_dVkAbaCoAPsQRIQmW21QPf_U2Fd1vdibnoExA_.rvTYv8; + _cfuvid=_7XoQBGwU.UsQgiPHVWMTXLLbADtbSwhrO9PY7I_3Dw-1749516047790-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_tTAThu8l2S9hNky2krdwijGP + name: get_user_country + type: function_call + - call_id: call_tTAThu8l2S9hNky2krdwijGP + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1902' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '883' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516047 + error: null + id: resp_68477f0fde708192989000a62809c6e5020197534e39cc1f + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"city":"Mexico City","country":"Mexico"}' + type: output_text + id: msg_68477f10846c81929f1e833b0785e6f3020197534e39cc1f + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 89 + input_tokens_details: + cached_tokens: 0 + output_tokens: 16 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 105 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml new file mode 100644 index 000000000..9c411f3c7 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml @@ -0,0 +1,444 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1143' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '3657' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '562' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516048 + error: null + id: resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + id: fc_68477f1168a081a3981e847cd94275080dd57d732903c563 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 153 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 165 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1379' + content-type: + - application/json + cookie: + - __cf_bm=3Nl1ERbtfVAI.dGjzCYYN1u71YD5eEoLU0iCrvPPPL0-1749516049-1.0.1.1-LnI7tJwKr.C_wA15Shsl8pcGd32zrRqqv_9u4S84nXtNCopx1iBIKYDsyMg3u1Z3lJ_1Cd1YVM8uKAMjiKmgoqS8GFQ3Z_vV_Mahvqbi4KA; + _cfuvid=oc_k9l86fnMo2ml.0aop6a3eVJEvjxB0lnxWK0_kJq8-1749516049524-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + name: get_user_country + type: function_call + - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '3800' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1042' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516049 + error: null + id: resp_68477f119830819da162aa6e10552035061ad97e2eef7871 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + type: output_text + id: msg_68477f1235b8819d898adc64709c7ebf061ad97e2eef7871 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 176 + input_tokens_details: + cached_tokens: 0 + output_tokens: 26 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 202 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml new file mode 100644 index 000000000..b2f1f6c8a --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml @@ -0,0 +1,70 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '676' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + instructions: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '224' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '13' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + parsed_body: + error: + code: null + message: Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. + param: input + type: invalid_request_error + status: + code: 400 + message: Bad Request +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..0be6e5777 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml @@ -0,0 +1,70 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1442' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + instructions: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '224' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '39' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + parsed_body: + error: + code: null + message: Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. + param: input + type: invalid_request_error + status: + code: 400 + message: Bad Request +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_text_output_function.yaml b/tests/models/cassettes/test_openai_responses/test_text_output_function.yaml new file mode 100644 index 000000000..ff4ff9acf --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_text_output_function.yaml @@ -0,0 +1,228 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '302' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1399' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '490' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516045 + error: null + id: resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_aTJhYjzmixZaVGqwl5gn2Ncr + id: fc_68477f0dff5c819ea17a1ffbaea621e00356a60c98816d6a + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 36 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 48 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '538' + content-type: + - application/json + cookie: + - __cf_bm=JZXeUMfyA2MKPG61ecku4K0wMqhJgj2ih66RpjtdqZk-1749516046-1.0.1.1-ZF5eievVR.Y5iPpLK_dVCJNl_ANFmiDhY4iZDFbopdvvhXnZvwLMCQVFWg.S.nQ0TvOw0it63SRuHbjo3jcjuD0lnI5oRQBJUOLiQElZ_j4; + _cfuvid=K7T3n3fgO8pCRHtSCoIpwTW2UEh0En8ro1rV5aPciGo-1749516046095-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_aTJhYjzmixZaVGqwl5gn2Ncr + name: get_user_country + type: function_call + - call_id: call_aTJhYjzmixZaVGqwl5gn2Ncr + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1485' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '825' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516046 + error: null + id: resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: The largest city in Mexico is Mexico City. + type: output_text + id: msg_68477f0ebf54819d88a44fa87aadaff503434b607c02582d + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 59 + input_tokens_details: + cached_tokens: 0 + output_tokens: 11 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 70 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_tool_output.yaml b/tests/models/cassettes/test_openai_responses/test_tool_output.yaml new file mode 100644 index 000000000..bc201f7c1 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_tool_output.yaml @@ -0,0 +1,282 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '556' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + tool_choice: required + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1854' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '568' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516043 + error: null + id: resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_ZWkVhdUjupo528U9dqgFeRkH + id: fc_68477f0bb8e4819cba6d781e174d77f8001fd29e2d5573f7 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: required + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 62 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 74 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '792' + content-type: + - application/json + cookie: + - __cf_bm=78_bxRDp8.6VLECkU4_YSNYd7PlmVGdN1E4j5KBkoOA-1749516043-1.0.1.1-Z9ZwaEzQZcS64A536kPafni6AZEqjCr1xDJ1h2WXjDrs0G_LuZPuq7Z27rs6w0.2DAk_UEY0.H.YMVFpWwe0QTOI28mlvDMbZvVsP6LT4Ug; + _cfuvid=Qym79CFc.nJ8O7pqDQfy1eFUEqIDIX3VuqfAl93F07o-1749516043838-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_ZWkVhdUjupo528U9dqgFeRkH + name: get_user_country + type: function_call + - call_id: call_ZWkVhdUjupo528U9dqgFeRkH + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + tool_choice: required + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1898' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '840' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516044 + error: null + id: resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{"city":"Mexico City","country":"Mexico"}' + call_id: call_iFBd0zULhSZRR908DfH73VwN + id: fc_68477f0c91cc819e8024e7e633f0f09401dc81d4bc91f560 + name: final_result + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: required + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 85 + input_tokens_details: + cached_tokens: 0 + output_tokens: 20 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 105 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 05bcecee5..cc6feb6fb 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -9,11 +9,12 @@ from dirty_equals import IsInstance from httpx import Request from inline_snapshot import snapshot +from pydantic import BaseModel from pytest_mock import MockerFixture from typing_extensions import TypedDict from pydantic_ai.agent import Agent -from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError from pydantic_ai.messages import ( BinaryContent, DocumentUrl, @@ -34,7 +35,7 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.usage import Usage +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage from ..conftest import IsDatetime, IsStr, try_import @@ -591,3 +592,480 @@ async def test_google_model_empty_user_prompt(allow_model_requests: None, google assert result.output == snapshot( 'Please provide me with a question or task. I need some information to be able to help you.\n' ) + + +async def test_google_tool_output(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=32, + response_tokens=5, + total_tokens=37, + details={'text_candidates_tokens': 5, 'text_prompt_tokens': 32}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'city': 'Mexico City', 'country': 'Mexico'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage( + requests=1, + request_tokens=46, + response_tokens=8, + total_tokens=54, + details={'text_candidates_tokens': 8, 'text_prompt_tokens': 46}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +async def test_google_text_output_function(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.5-pro-preview-05-06', provider=google_provider) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot('BASED ON THE INFORMATION I HAVE, THE LARGEST CITY IN MEXICO IS MEXICO CITY.') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content='Okay, I can help with that. First, I need to determine your country.\n'), + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr()), + ], + usage=Usage( + requests=1, + request_tokens=49, + response_tokens=30, + total_tokens=238, + details={'thoughts_tokens': 159, 'text_prompt_tokens': 49}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Based on the information I have, the largest city in Mexico is Mexico City.')], + usage=Usage( + requests=1, + request_tokens=98, + response_tokens=16, + total_tokens=159, + details={'thoughts_tokens': 45, 'text_prompt_tokens': 98}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_json_schema_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): + await agent.run('What is the largest city in the user country?') + + +async def test_google_json_schema_output(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "city": "Mexico City", + "country": "Mexico" +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=19, + response_tokens=20, + total_tokens=39, + details={'text_candidates_tokens': 20, 'text_prompt_tokens': 19}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_json_schema_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the primarily language spoken in Mexico?') + assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the primarily language spoken in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "result": { + "kind": "CountryLanguage", + "data": { + "country": "Mexico", + "language": "Spanish" + } + } +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=64, + response_tokens=46, + total_tokens=110, + details={'text_candidates_tokens': 46, 'text_prompt_tokens': 64}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_prompted_json_output(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + ) + ], + usage=Usage( + requests=1, + request_tokens=80, + response_tokens=56, + total_tokens=136, + details={'text_candidates_tokens': 56, 'text_prompt_tokens': 80}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_prompted_json_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.5-pro-preview-05-06', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_user_country', args={}, tool_call_id='pyd_ai_479a74a75212414fb3c7bd2242e9b669' + ) + ], + usage=Usage( + requests=1, + request_tokens=123, + response_tokens=12, + total_tokens=401, + details={'thoughts_tokens': 266, 'text_prompt_tokens': 123}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='pyd_ai_479a74a75212414fb3c7bd2242e9b669', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +```json +{"city": "Mexico City", "country": "Mexico"} +```\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=154, + response_tokens=18, + total_tokens=266, + details={'thoughts_tokens': 94, 'text_prompt_tokens': 154}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_prompted_json_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): + import logfire + + logfire.configure() + logfire.instrument_pydantic_ai() + + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=241, + response_tokens=27, + total_tokens=268, + details={'text_candidates_tokens': 27, 'text_prompt_tokens': 241}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index c644e03f8..8d2330593 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -15,7 +15,6 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior -from pydantic_ai._output import ManualJsonOutput, TextOutput from pydantic_ai.messages import ( AudioUrl, BinaryContent, @@ -35,7 +34,7 @@ from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import JsonSchemaOutput, ToolOutput, Usage +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -2020,15 +2019,12 @@ class CountryLanguage(BaseModel): country: str language: str - # TODO: Test with functions! agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: return 'Mexico' - # TODO: Show what response_format looks like - result = await agent.run('What is the largest city in the user country?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -2044,7 +2040,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ - ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_NiLmkD3Yi30ax2IY7t14e3AP') + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_SIttSeiOistt33Htj4oiHOOX') ], usage=Usage( requests=1, @@ -2061,14 +2057,14 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BeCFgrwfENi1OwavvP8itSOMTKjwY', + vendor_id='chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD', ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='call_NiLmkD3Yi30ax2IY7t14e3AP', + tool_call_id='call_SIttSeiOistt33Htj4oiHOOX', timestamp=IsDatetime(), ) ] @@ -2094,21 +2090,21 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BeCFiQtbjmFUzYbmXlkAEWbc0peoL', + vendor_id='chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd', ), ] ) @pytest.mark.vcr() -async def test_openai_manual_json_output(allow_model_requests: None, openai_api_key: str): +async def test_openai_prompted_json_output(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=ManualJsonOutput(CityLocation)) + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -2127,9 +2123,7 @@ async def get_user_country() -> str: ) ], instructions="""\ -Always respond with a JSON object matching this description and schema: - -CityLocation +Always respond with a JSON object that's compatible with this schema: {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} @@ -2138,13 +2132,13 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ - ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_uTjt2vMkeTr0GYqQyQYrUUhl') + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB') ], usage=Usage( requests=1, - request_tokens=106, - response_tokens=12, - total_tokens=118, + request_tokens=109, + response_tokens=11, + total_tokens=120, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2155,21 +2149,19 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BWmxcJf2wXM37xTze50IfAoyuaoKb', + vendor_id='chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw', ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='call_uTjt2vMkeTr0GYqQyQYrUUhl', + tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB', timestamp=IsDatetime(), ) ], instructions="""\ -Always respond with a JSON object matching this description and schema: - -CityLocation +Always respond with a JSON object that's compatible with this schema: {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} @@ -2180,9 +2172,9 @@ async def get_user_country() -> str: parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], usage=Usage( requests=1, - request_tokens=127, - response_tokens=12, - total_tokens=139, + request_tokens=130, + response_tokens=11, + total_tokens=141, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2193,14 +2185,14 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BWmxdIs5pO5RCbQ9qRxxtWVItB4NU', + vendor_id='chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0', ), ] ) @pytest.mark.vcr() -async def test_openai_manual_json_output_multiple(allow_model_requests: None, openai_api_key: str): +async def test_openai_prompted_json_output_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -2211,15 +2203,12 @@ class CountryLanguage(BaseModel): country: str language: str - # TODO: Test with functions! - agent = Agent(m, output_type=ManualJsonOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: return 'Mexico' - # TODO: Show what response_format looks like - result = await agent.run('What is the largest city in the user country?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -2233,9 +2222,7 @@ async def get_user_country() -> str: ) ], instructions="""\ -Always respond with a JSON object matching this description and schema: - -final_result: The final response which ends this conversation +Always respond with a JSON object that's compatible with this schema: {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} @@ -2244,13 +2231,13 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ - ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_W3kwVF2ZX9cZ2L9NnbSDSs3V') + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09') ], usage=Usage( requests=1, - request_tokens=284, + request_tokens=273, response_tokens=11, - total_tokens=295, + total_tokens=284, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2261,21 +2248,19 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BeESepFYXE1ELAEvKlRNvsRzzYkOg', + vendor_id='chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR', ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='call_W3kwVF2ZX9cZ2L9NnbSDSs3V', + tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09', timestamp=IsDatetime(), ) ], instructions="""\ -Always respond with a JSON object matching this description and schema: - -final_result: The final response which ends this conversation +Always respond with a JSON object that's compatible with this schema: {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} @@ -2290,9 +2275,9 @@ async def get_user_country() -> str: ], usage=Usage( requests=1, - request_tokens=305, + request_tokens=294, response_tokens=21, - total_tokens=326, + total_tokens=315, details={ 'accepted_prediction_tokens': 0, 'audio_tokens': 0, @@ -2303,7 +2288,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), - vendor_id='chatcmpl-BeESfsdYDCQwP7kGM4r5i8KXwllkT', + vendor_id='chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC', ), ] ) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index 736ceae08..e13afe416 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -3,6 +3,7 @@ import pytest from inline_snapshot import snapshot +from pydantic import BaseModel from typing_extensions import TypedDict from pydantic_ai.agent import Agent @@ -20,6 +21,7 @@ UserPromptPart, ) from pydantic_ai.profiles.openai import openai_model_profile +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import Usage @@ -505,3 +507,335 @@ def test_model_profile_strict_not_supported(): 'strict': False, } ) + + +@pytest.mark.vcr() +async def test_tool_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_ZWkVhdUjupo528U9dqgFeRkH'), + ], + usage=Usage( + request_tokens=62, + response_tokens=12, + total_tokens=74, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_ZWkVhdUjupo528U9dqgFeRkH', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart( + tool_name='final_result', + args='{"city":"Mexico City","country":"Mexico"}', + tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', + ), + ], + usage=Usage( + request_tokens=85, + response_tokens=20, + total_tokens=105, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +@pytest.mark.vcr() +async def test_text_output_function(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot('THE LARGEST CITY IN MEXICO IS MEXICO CITY.') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr'), + ], + usage=Usage( + request_tokens=36, + response_tokens=12, + total_tokens=48, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='The largest city in Mexico is Mexico City.')], + usage=Usage( + request_tokens=59, + response_tokens=11, + total_tokens=70, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_json_schema_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_tTAThu8l2S9hNky2krdwijGP'), + ], + usage=Usage( + request_tokens=66, + response_tokens=12, + total_tokens=78, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_tTAThu8l2S9hNky2krdwijGP', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + request_tokens=89, + response_tokens=16, + total_tokens=105, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_json_schema_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_UaLahjOtaM2tTyYZLxTCbOaP'), + ], + usage=Usage( + request_tokens=153, + response_tokens=12, + total_tokens=165, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_UaLahjOtaM2tTyYZLxTCbOaP', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + request_tokens=176, + response_tokens=26, + total_tokens=202, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_prompted_json_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot() + + +@pytest.mark.vcr() +async def test_prompted_json_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot() From 52ef4d58d277084bc873b9d691350c08366a7058 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 13:13:16 +0000 Subject: [PATCH 09/90] Fix prompted_json on OpenAIResponses --- pydantic_ai_slim/pydantic_ai/models/openai.py | 7 + .../test_prompted_json_output.yaml | 206 ++++++++++++++++-- .../test_prompted_json_output_multiple.yaml | 206 ++++++++++++++++-- tests/models/test_openai_responses.py | 128 ++++++++++- 4 files changed, 517 insertions(+), 30 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index c4c903434..f634ca686 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -696,6 +696,13 @@ async def _responses_create( ): text = {'format': {'type': 'json_object'}} + if isinstance(instructions, str): + # Without this trick, we'd hit this error: + # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. + # Apparently they're only checking input messages for "JSON", not instructions. + openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) + instructions = NOT_GIVEN + try: extra_headers = model_settings.get('extra_headers', {}) extra_headers.setdefault('User-Agent', get_user_agent()) diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml index b2f1f6c8a..35783c516 100644 --- a/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml +++ b/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '676' + - '689' content-type: - application/json host: @@ -16,14 +16,143 @@ interactions: method: POST parsed_body: input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system - content: What is the largest city in the user country? role: user - instructions: |- - Always respond with a JSON object that's compatible with this schema: + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1408' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '8314' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561106 + error: null + id: resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + id: fc_68482f1b0ff081a1b37b9170ee740d1e02f8ef7f2fb42b50 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 107 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 119 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '925' + content-type: + - application/json + cookie: + - __cf_bm=8a8rNQQYozQt3YjcA61k6KGe.AlrMMrtcIvKv.D1s1E-1749561115-1.0.1.1-OFcqg8xD2_HdbeO74bU2.mLTqDuiK.ploHeu3_ITPvDlGwrVkwk8erMkHagxk4UDxACCCAygnUs1HL.F4AGjQCaZm1m2eYiMVbLqp0iQh7g; + _cfuvid=wKTRRc2dbdYNYnYwA2vRxVjUvqqkQovvKDwULW0Xwns-1749561115173-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - Don't include any text or Markdown fencing before or after. + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + name: get_user_country + type: function_call + - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + output: Mexico + type: function_call_output model: gpt-4o stream: false text: @@ -47,24 +176,73 @@ interactions: connection: - keep-alive content-length: - - '224' + - '1501' content-type: - application/json openai-organization: - pydantic-28gund openai-processing-ms: - - '13' + - '1098' openai-version: - '2020-10-01' strict-transport-security: - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked parsed_body: - error: - code: null - message: Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. - param: input - type: invalid_request_error + background: false + created_at: 1749561115 + error: null + id: resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"city":"Mexico City","country":"Mexico"}' + type: output_text + id: msg_68482f1c159081918a2405f458009a6a044fdb7d019d4115 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 130 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 142 + user: null status: - code: 400 - message: Bad Request + code: 200 + message: OK version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml index 0be6e5777..1a3b4dc00 100644 --- a/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml +++ b/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '1442' + - '1455' content-type: - application/json host: @@ -16,14 +16,143 @@ interactions: method: POST parsed_body: input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system - content: What is the largest city in the user country? role: user - instructions: |- - Always respond with a JSON object that's compatible with this schema: + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1408' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '11445' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561117 + error: null + id: resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + id: fc_68482f2889d481a199caa61de7ccb62c08e79646fe74d5ee + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 283 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 295 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1691' + content-type: + - application/json + cookie: + - __cf_bm=l95LdgPzGHw0UAhBwse9ADphgmMDWrhYqgiO4gdmSy4-1749561128-1.0.1.1-9zPIs3d5_ipszLpQ7yBaCZEStp8qoRIGFshR93V6n7Z_7AznH0MfuczwuoiaW8e6cEVeVHLhskjXScolO9gP5TmpsaFo37GRuHsHZTRgEeI; + _cfuvid=5L5qtbtbFCFzMmoVufSY.ksn06ay8AFs.UXFEv07pkY-1749561128680-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - Don't include any text or Markdown fencing before or after. + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + name: get_user_country + type: function_call + - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + output: Mexico + type: function_call_output model: gpt-4o stream: false text: @@ -47,24 +176,73 @@ interactions: connection: - keep-alive content-length: - - '224' + - '1551' content-type: - application/json openai-organization: - pydantic-28gund openai-processing-ms: - - '39' + - '2545' openai-version: - '2020-10-01' strict-transport-security: - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked parsed_body: - error: - code: null - message: Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. - param: input - type: invalid_request_error + background: false + created_at: 1749561128 + error: null + id: resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + type: output_text + id: msg_68482f296bfc81a18665547d4008ab2c06b4ab2d00d03024 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 306 + input_tokens_details: + cached_tokens: 0 + output_tokens: 22 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 328 + user: null status: - code: 400 - message: Bad Request + code: 200 + message: OK version: 1 diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index e13afe416..5ba449a9c 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -814,7 +814,67 @@ async def get_user_country() -> str: result = await agent.run('What is the largest city in the user country?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) - assert result.all_messages() == snapshot() + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_FrlL4M0CbAy8Dhv4VqF1Shom'), + ], + usage=Usage( + request_tokens=107, + response_tokens=12, + total_tokens=119, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_FrlL4M0CbAy8Dhv4VqF1Shom', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + request_tokens=130, + response_tokens=12, + total_tokens=142, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) @pytest.mark.vcr() @@ -838,4 +898,68 @@ async def get_user_country() -> str: result = await agent.run('What is the largest city in the user country?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) - assert result.all_messages() == snapshot() + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_my4OyoVXRT0m7bLWmsxcaCQI'), + ], + usage=Usage( + request_tokens=283, + response_tokens=12, + total_tokens=295, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_my4OyoVXRT0m7bLWmsxcaCQI', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + request_tokens=306, + response_tokens=22, + total_tokens=328, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) From fe059565cee30d92c0a1b8a230b3ec0618c587ee Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 14:01:03 +0000 Subject: [PATCH 10/90] Test output modes on Gemini and Anthropic --- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 18 +- pydantic_ai_slim/pydantic_ai/models/google.py | 4 +- .../test_anthropic_tool_output.yaml | 176 +++++++ tests/models/test_anthropic.py | 349 ++++++++++++- tests/models/test_gemini.py | 485 +++++++++++++++++- tests/models/test_google.py | 5 - 7 files changed, 1020 insertions(+), 19 deletions(-) create mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_tool_output.yaml diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 95e89a85a..22d679017 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -613,7 +613,7 @@ def __init__( description = f'{description}. {json_schema_description}' self.object_def = OutputObjectDefinition( - name=name or getattr(output_type, '__name__'), + name=name or getattr(output_type, '__name__', None), description=description, json_schema=json_schema, strict=strict, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 1cfc5a364..c8bed2335 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -195,11 +195,11 @@ def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: if not tools: - return _tool_config([]) # pragma: no cover + return None elif model_request_parameters.output_mode == 'tool': return _tool_config([t['name'] for t in tools['function_declarations']]) else: - return None + return _tool_config([]) # pragma: no cover @asynccontextmanager async def _make_request( @@ -221,20 +221,21 @@ async def _make_request( if tool_config is not None: request_data['toolConfig'] = tool_config + generation_config = _settings_to_generation_config(model_settings) + output_mode = model_request_parameters.output_mode if output_mode == 'json_schema': - request_data['responseMimeType'] = 'application/json' + generation_config['response_mime_type'] = 'application/json' output_object = model_request_parameters.output_object assert output_object is not None - request_data['responseSchema'] = self._map_response_schema(output_object) + generation_config['response_schema'] = self._map_response_schema(output_object) if tools: raise UserError('Google does not support JSON schema output and tools at the same time.') elif output_mode == 'prompted_json' and not tools: - request_data['responseMimeType'] = 'application/json' + generation_config['response_mime_type'] = 'application/json' - generation_config = _settings_to_generation_config(model_settings) if generation_config: request_data['generationConfig'] = generation_config @@ -528,9 +529,6 @@ class _GeminiRequest(TypedDict): generationConfig: NotRequired[_GeminiGenerationConfig] labels: NotRequired[dict[str, str]] - responseMimeType: NotRequired[str] - responseSchema: NotRequired[dict[str, Any]] - class GeminiSafetySettings(TypedDict): """Safety settings options for Gemini model request. @@ -589,6 +587,8 @@ class _GeminiGenerationConfig(TypedDict, total=False): frequency_penalty: float stop_sequences: list[str] thinking_config: ThinkingConfig + response_mime_type: str + response_schema: dict[str, Any] class _GeminiContent(TypedDict): diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 57612567f..4a7d4d7a6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -216,7 +216,7 @@ def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: if not tools: - return _tool_config([]) # pragma: no cover + return None elif model_request_parameters.output_mode == 'tool': names: list[str] = [] for tool in tools: @@ -225,7 +225,7 @@ def _get_tool_config( names.append(name) return _tool_config(names) else: - return None + return _tool_config([]) # pragma: no cover @overload async def _generate_content( diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_tool_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_tool_output.yaml new file mode 100644 index 000000000..560e1f34c --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_tool_output.yaml @@ -0,0 +1,176 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '585' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: any + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + - description: The final response which ends this conversation + input_schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + title: CityLocation + type: object + name: final_result + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '397' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - id: toolu_019pMboNVRg5jkw4PKkofQ6Y + input: {} + name: get_user_country + type: tool_use + id: msg_01EnfsDTixCmHjqvk9QarBj4 + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 445 + output_tokens: 23 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '847' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? + type: text + role: user + - content: + - id: toolu_019pMboNVRg5jkw4PKkofQ6Y + input: {} + name: get_user_country + type: tool_use + role: assistant + - content: + - content: Mexico + is_error: false + tool_use_id: toolu_019pMboNVRg5jkw4PKkofQ6Y + type: tool_result + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: any + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + - description: The final response which ends this conversation + input_schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + title: CityLocation + type: object + name: final_result + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '432' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - id: toolu_01V4d2H4EWp5LDM2aXaeyR6W + input: + city: Mexico City + country: Mexico + name: final_result + type: tool_use + id: msg_01Hbm5BtKzfVtWs8Eb7rCNNx + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 497 + output_tokens: 56 + service_tier: standard + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 8490d6204..1bd395ad1 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -11,6 +11,7 @@ import httpx import pytest from inline_snapshot import snapshot +from pydantic import BaseModel from pydantic_ai import Agent, ModelHTTPError, ModelRetry from pydantic_ai.messages import ( @@ -26,7 +27,7 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.result import Usage +from pydantic_ai.result import PromptedJsonOutput, TextOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from ..conftest import IsDatetime, IsNow, IsStr, TestEnv, raise_if_exception, try_import @@ -1063,3 +1064,349 @@ async def test_anthropic_model_empty_message_on_history(allow_model_requests: No What specifically would you like to know about potatoes?\ """) + + +@pytest.mark.vcr() +async def test_anthropic_tool_output(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_019pMboNVRg5jkw4PKkofQ6Y') + ], + usage=Usage( + requests=1, + request_tokens=445, + response_tokens=23, + total_tokens=468, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 445, + 'output_tokens': 23, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01EnfsDTixCmHjqvk9QarBj4', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='toolu_019pMboNVRg5jkw4PKkofQ6Y', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'city': 'Mexico City', 'country': 'Mexico'}, + tool_call_id='toolu_01V4d2H4EWp5LDM2aXaeyR6W', + ) + ], + usage=Usage( + requests=1, + request_tokens=497, + response_tokens=56, + total_tokens=553, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 497, + 'output_tokens': 56, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01Hbm5BtKzfVtWs8Eb7rCNNx', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='toolu_01V4d2H4EWp5LDM2aXaeyR6W', + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +async def test_anthropic_text_output_function(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot("""\ +BASED ON THE RESULT, YOU ARE LOCATED IN MEXICO. \n\ + +MEXICO CITY (CIUDAD DE MÉXICO) IS THE LARGEST CITY IN MEXICO, WITH A METROPOLITAN AREA POPULATION OF OVER 22 MILLION PEOPLE (2022 ESTIMATES). IT IS NOT ONLY THE LARGEST CITY IN MEXICO BUT ALSO ONE OF THE LARGEST METROPOLITAN AREAS IN THE WORLD. THE CITY PROPER HAS AROUND 9 MILLION INHABITANTS, WHILE THE GREATER METROPOLITAN AREA INCLUDES MANY SURROUNDING MUNICIPALITIES AND EXTENDS INTO THE STATE OF MEXICO. + +MEXICO CITY SERVES AS THE COUNTRY'S CAPITAL AND ITS MOST IMPORTANT POLITICAL, ECONOMIC, AND CULTURAL CENTER. IT'S LOCATED IN THE VALLEY OF MEXICO IN THE CENTRAL PART OF THE COUNTRY.\ +""") + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="I'll help you find the largest city in your country. Let me first check your country using the get_user_country tool." + ), + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01NtJsHFTSiiBoKpnzGBsg5C'), + ], + usage=Usage( + requests=1, + request_tokens=383, + response_tokens=66, + total_tokens=449, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 383, + 'output_tokens': 66, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01YJQ6kNbtvNpwJboZ9peSEq', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='toolu_01NtJsHFTSiiBoKpnzGBsg5C', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +Based on the result, you are located in Mexico. \n\ + +Mexico City (Ciudad de México) is the largest city in Mexico, with a metropolitan area population of over 22 million people (2022 estimates). It is not only the largest city in Mexico but also one of the largest metropolitan areas in the world. The city proper has around 9 million inhabitants, while the greater metropolitan area includes many surrounding municipalities and extends into the State of Mexico. + +Mexico City serves as the country's capital and its most important political, economic, and cultural center. It's located in the Valley of Mexico in the central part of the country.\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=461, + response_tokens=135, + total_tokens=596, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 461, + 'output_tokens': 135, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_013YNHomHSuhmzWZq8iWjvjK', + ), + ] + ) + + +async def test_anthropic_prompted_json_output(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01FdQREaVXQbaH7JrFQaTzKb') + ], + usage=Usage( + requests=1, + request_tokens=459, + response_tokens=38, + total_tokens=497, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 459, + 'output_tokens': 38, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01FjTTxFKgUXP2cp5s3d7fYh', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='toolu_01FdQREaVXQbaH7JrFQaTzKb', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], + usage=Usage( + requests=1, + request_tokens=510, + response_tokens=17, + total_tokens=527, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 510, + 'output_tokens': 17, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01HrNfqrq9UGB54S5xhbCWY5', + ), + ] + ) + + +async def test_anthropic_prompted_json_output_multiple(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=281, + response_tokens=31, + total_tokens=312, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 281, + 'output_tokens': 31, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01HyWKS3uRkhUw5mWKJY2iZN', + ), + ] + ) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index e4a965a10..5f61961ec 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -53,7 +53,7 @@ _GeminiUsageMetaData, ) from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage from pydantic_ai.tools import ToolDefinition from ..conftest import ClientWithHandler, IsDatetime, IsNow, IsStr, TestEnv @@ -1398,3 +1398,486 @@ async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): assert result.output == 'Hello from thought test' assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + + +async def test_gemini_tool_output(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_user_country', args={}, tool_call_id='pyd_ai_937a1fb3f3d0401ab90c03f501c4c778' + ) + ], + usage=Usage( + requests=1, + request_tokens=32, + response_tokens=5, + total_tokens=37, + details={'text_prompt_tokens': 32, 'text_candidates_tokens': 5}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='XzBIaKv5NL_WnvgP8JDUkAs', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='pyd_ai_937a1fb3f3d0401ab90c03f501c4c778', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'country': 'Mexico', 'city': 'Mexico City'}, + tool_call_id='pyd_ai_b5e34b7710534c728a9684c48a13fcbb', + ) + ], + usage=Usage( + requests=1, + request_tokens=46, + response_tokens=8, + total_tokens=54, + details={'text_prompt_tokens': 46, 'text_candidates_tokens': 8}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='YDBIaOaAFeS9nvgPzP-Y2QE', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='pyd_ai_b5e34b7710534c728a9684c48a13fcbb', + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +async def test_gemini_text_output_function(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot('MEXICO CITY IS THE LARGEST CITY IN MEXICO.') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_user_country', args={}, tool_call_id='pyd_ai_0676ff40327f4dadb1763dfaa982e2a7' + ) + ], + usage=Usage( + requests=1, + request_tokens=49, + response_tokens=12, + total_tokens=108, + details={'thoughts_tokens': 47, 'text_prompt_tokens': 49}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='kzBIaKSUI6Wtz7IPnIPN8Ac', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='pyd_ai_0676ff40327f4dadb1763dfaa982e2a7', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Mexico City is the largest city in Mexico.')], + usage=Usage( + requests=1, + request_tokens=80, + response_tokens=9, + total_tokens=155, + details={'thoughts_tokens': 66, 'text_prompt_tokens': 80}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='lTBIaMDBEI6cz7IPp5vHqAo', + ), + ] + ) + + +async def test_gemini_json_schema_output_with_tools(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): + await agent.run('What is the largest city in the user country?') + + +async def test_gemini_json_schema_output(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "city": "Mexico City", + "country": "Mexico" +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=17, + response_tokens=20, + total_tokens=37, + details={'text_prompt_tokens': 17, 'text_candidates_tokens': 20}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='dDFIaNO7EbjX1PIPzJ2c-Qo', + ), + ] + ) + + +async def test_gemini_json_schema_output_multiple(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the primarily language spoken in Mexico?') + assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the primarily language spoken in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "result": { + "data": { + "country": "Mexico", + "language": "Spanish" + }, + "kind": "CountryLanguage" + } +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=46, + response_tokens=46, + total_tokens=92, + details={'text_prompt_tokens': 46, 'text_candidates_tokens': 46}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='ljFIaMSvEYCK7dcP3OzRiQQ', + ), + ] + ) + + +async def test_gemini_prompted_json_output(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], + usage=Usage( + requests=1, + request_tokens=80, + response_tokens=13, + total_tokens=93, + details={'text_prompt_tokens': 80, 'text_candidates_tokens': 13}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='vjFIaLnCK9GU7dcPjoS34QI', + ), + ] + ) + + +async def test_gemini_prompted_json_output_with_tools(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_user_country', args={}, tool_call_id='pyd_ai_fc31eb899081445ea0a4369584e16f99' + ) + ], + usage=Usage( + requests=1, + request_tokens=123, + response_tokens=12, + total_tokens=2243, + details={'thoughts_tokens': 2108, 'text_prompt_tokens': 123}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='1DFIaLOMN_iHz7IP78G_6Ac', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='pyd_ai_fc31eb899081445ea0a4369584e16f99', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +```json +{"city": "Mexico City", "country": "Mexico"} +```\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=154, + response_tokens=18, + total_tokens=270, + details={'thoughts_tokens': 98, 'text_prompt_tokens': 154}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='1jFIaJHOJ-itz7IPvau3kQM', + ), + ] + ) + + +async def test_gemini_prompted_json_output_multiple(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=253, + response_tokens=27, + total_tokens=280, + details={'text_prompt_tokens': 253, 'text_candidates_tokens': 27}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='6TFIaNuTKoCK7dcP3OzRiQQ', + ), + ] + ) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index cc6feb6fb..a62abf188 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -1013,11 +1013,6 @@ async def get_user_country() -> str: async def test_google_prompted_json_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): - import logfire - - logfire.configure() - logfire.instrument_pydantic_ai() - m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): From 94421f3f4da09ec6d1884ea8fd1a18836353d206 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 14:24:31 +0000 Subject: [PATCH 11/90] Add VCR recordings of Gemini output mode tests --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 +- pydantic_ai_slim/pydantic_ai/models/google.py | 6 +- ...st_gemini_json_schema_output_multiple.yaml | 120 ++++++++++++ .../test_gemini_prompted_json_output.yaml | 74 +++++++ ..._gemini_prompted_json_output_multiple.yaml | 73 +++++++ ...emini_prompted_json_output_with_tools.yaml | 157 +++++++++++++++ .../test_gemini_text_output_function.yaml | 63 ++++++ .../test_gemini/test_gemini_tool_output.yaml | 183 ++++++++++++++++++ tests/models/test_gemini.py | 131 +++++-------- 9 files changed, 726 insertions(+), 87 deletions(-) create mode 100644 tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_text_output_function.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_tool_output.yaml diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c8bed2335..14c996acb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -194,12 +194,10 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: - if not tools: - return None - elif model_request_parameters.output_mode == 'tool': + if model_request_parameters.output_mode == 'tool' and tools: return _tool_config([t['name'] for t in tools['function_declarations']]) else: - return _tool_config([]) # pragma: no cover + return None @asynccontextmanager async def _make_request( diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 4a7d4d7a6..d041d13b4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -215,9 +215,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: - if not tools: - return None - elif model_request_parameters.output_mode == 'tool': + if model_request_parameters.output_mode == 'tool' and tools: names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: @@ -225,7 +223,7 @@ def _get_tool_config( names.append(name) return _tool_config(names) else: - return _tool_config([]) # pragma: no cover + return None @overload async def _generate_content( diff --git a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml new file mode 100644 index 000000000..3b306d133 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml @@ -0,0 +1,120 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '791' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the primarily language spoken in Mexico? + role: user + generationConfig: + response_mime_type: application/json + response_schema: + properties: + result: + anyOf: + - description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + enum: + - CityLocation + type: string + required: + - kind + - data + type: object + - description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + enum: + - CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '800' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=963 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -3.3667640072172103e-06 + content: + parts: + - text: |- + { + "result": { + "data": { + "country": "Mexico", + "language": "Spanish" + }, + "kind": "CountryLanguage" + } + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 2jxIaPucEYCK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 46 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 46 + promptTokenCount: 46 + promptTokensDetails: + - modality: TEXT + tokenCount: 46 + totalTokenCount: 92 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml new file mode 100644 index 000000000..2268e7f84 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml @@ -0,0 +1,74 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '521' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '880' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=841 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.007913463882037572 + content: + parts: + - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], + "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 2zxIaIiLE4CK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 56 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 56 + promptTokenCount: 80 + promptTokensDetails: + - modality: TEXT + tokenCount: 80 + totalTokenCount: 136 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..e96fc20d7 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml @@ -0,0 +1,73 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1287' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '757' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=823 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0030997690779191477 + content: + parts: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: Wz1IaOH5OdGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 27 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 27 + promptTokenCount: 253 + promptTokensDetails: + - modality: TEXT + tokenCount: 253 + totalTokenCount: 280 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml new file mode 100644 index 000000000..f10da3ad7 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml @@ -0,0 +1,157 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '615' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '653' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=4501 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: rj9IaPTzNdCBqtsPg-GD6QU + usageMetadata: + candidatesTokenCount: 12 + promptTokenCount: 123 + promptTokensDetails: + - modality: TEXT + tokenCount: 123 + thoughtsTokenCount: 318 + totalTokenCount: 453 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '809' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - functionCall: + args: {} + name: get_user_country + role: model + - parts: + - functionResponse: + name: get_user_country + response: + return_value: Mexico + role: user + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '616' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1823 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: sD9IaOCyLPqumtkP6p_T0AE + usageMetadata: + candidatesTokenCount: 13 + promptTokenCount: 154 + promptTokensDetails: + - modality: TEXT + tokenCount: 154 + thoughtsTokenCount: 94 + totalTokenCount: 261 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_text_output_function.yaml b/tests/models/cassettes/test_gemini/test_gemini_text_output_function.yaml new file mode 100644 index 000000000..7d54ce938 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_text_output_function.yaml @@ -0,0 +1,63 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '87' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '753' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=6856 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: |- + The largest city in Mexico is **Mexico City (Ciudad de México, CDMX)**. + + It's the capital of Mexico and one of the largest metropolitan areas in the world, both by population and land area. + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: TT9IaNfGN_DmqtsPzKnE4AE + usageMetadata: + candidatesTokenCount: 44 + promptTokenCount: 9 + promptTokensDetails: + - modality: TEXT + tokenCount: 9 + thoughtsTokenCount: 545 + totalTokenCount: 598 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_tool_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_tool_output.yaml new file mode 100644 index 000000000..f0c7adc68 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_tool_output.yaml @@ -0,0 +1,183 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '511' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + toolConfig: + function_calling_config: + allowed_function_names: + - get_user_country + - final_result + mode: ANY + tools: + functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '733' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=591 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: 5.670217797160149e-06 + content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: SDxIaMqaGOS9nvgPzP-Y2QE + usageMetadata: + candidatesTokenCount: 5 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 5 + promptTokenCount: 32 + promptTokensDetails: + - modality: TEXT + tokenCount: 32 + totalTokenCount: 37 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '705' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + - parts: + - functionCall: + args: {} + name: get_user_country + role: model + - parts: + - functionResponse: + name: get_user_country + response: + return_value: Mexico + role: user + toolConfig: + function_calling_config: + allowed_function_names: + - get_user_country + - final_result + mode: ANY + tools: + functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '821' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=613 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -3.3069271012209356e-05 + content: + parts: + - functionCall: + args: + city: Mexico City + country: Mexico + name: final_result + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: SDxIaNHrNcy3nvgPm5DhwQo + usageMetadata: + candidatesTokenCount: 8 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 8 + promptTokenCount: 46 + promptTokensDetails: + - modality: TEXT + tokenCount: 46 + totalTokenCount: 54 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 5f61961ec..d4d765c73 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1400,6 +1400,7 @@ async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) +@pytest.mark.vcr() async def test_gemini_tool_output(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) @@ -1427,11 +1428,7 @@ async def get_user_country() -> str: ] ), ModelResponse( - parts=[ - ToolCallPart( - tool_name='get_user_country', args={}, tool_call_id='pyd_ai_937a1fb3f3d0401ab90c03f501c4c778' - ) - ], + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], usage=Usage( requests=1, request_tokens=32, @@ -1442,14 +1439,14 @@ async def get_user_country() -> str: model_name='gemini-2.0-flash', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='XzBIaKv5NL_WnvgP8JDUkAs', + vendor_id=IsStr(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='pyd_ai_937a1fb3f3d0401ab90c03f501c4c778', + tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] @@ -1459,7 +1456,7 @@ async def get_user_country() -> str: ToolCallPart( tool_name='final_result', args={'country': 'Mexico', 'city': 'Mexico City'}, - tool_call_id='pyd_ai_b5e34b7710534c728a9684c48a13fcbb', + tool_call_id=IsStr(), ) ], usage=Usage( @@ -1472,14 +1469,14 @@ async def get_user_country() -> str: model_name='gemini-2.0-flash', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='YDBIaOaAFeS9nvgPzP-Y2QE', + vendor_id=IsStr(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='final_result', content='Final result processed.', - tool_call_id='pyd_ai_b5e34b7710534c728a9684c48a13fcbb', + tool_call_id=IsStr(), timestamp=IsDatetime(), ) ] @@ -1488,6 +1485,7 @@ async def get_user_country() -> str: ) +@pytest.mark.vcr() async def test_gemini_text_output_function(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) @@ -1496,71 +1494,50 @@ def upcase(text: str) -> str: agent = Agent(m, output_type=TextOutput(upcase)) - @agent.tool_plain - async def get_user_country() -> str: - return 'Mexico' + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot("""\ +THE LARGEST CITY IN MEXICO IS **MEXICO CITY (CIUDAD DE MÉXICO, CDMX)**. - result = await agent.run( - 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' - ) - assert result.output == snapshot('MEXICO CITY IS THE LARGEST CITY IN MEXICO.') +IT'S THE CAPITAL OF MEXICO AND ONE OF THE LARGEST METROPOLITAN AREAS IN THE WORLD, BOTH BY POPULATION AND LAND AREA.\ +""") assert result.all_messages() == snapshot( [ ModelRequest( parts=[ UserPromptPart( - content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + content='What is the largest city in Mexico?', timestamp=IsDatetime(), ) ] ), ModelResponse( parts=[ - ToolCallPart( - tool_name='get_user_country', args={}, tool_call_id='pyd_ai_0676ff40327f4dadb1763dfaa982e2a7' + TextPart( + content="""\ +The largest city in Mexico is **Mexico City (Ciudad de México, CDMX)**. + +It's the capital of Mexico and one of the largest metropolitan areas in the world, both by population and land area.\ +""" ) ], usage=Usage( requests=1, - request_tokens=49, - response_tokens=12, - total_tokens=108, - details={'thoughts_tokens': 47, 'text_prompt_tokens': 49}, - ), - model_name='models/gemini-2.5-pro-preview-05-06', - timestamp=IsDatetime(), - vendor_details={'finish_reason': 'STOP'}, - vendor_id='kzBIaKSUI6Wtz7IPnIPN8Ac', - ), - ModelRequest( - parts=[ - ToolReturnPart( - tool_name='get_user_country', - content='Mexico', - tool_call_id='pyd_ai_0676ff40327f4dadb1763dfaa982e2a7', - timestamp=IsDatetime(), - ) - ] - ), - ModelResponse( - parts=[TextPart(content='Mexico City is the largest city in Mexico.')], - usage=Usage( - requests=1, - request_tokens=80, - response_tokens=9, - total_tokens=155, - details={'thoughts_tokens': 66, 'text_prompt_tokens': 80}, + request_tokens=9, + response_tokens=44, + total_tokens=598, + details={'thoughts_tokens': 545, 'text_prompt_tokens': 9}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='lTBIaMDBEI6cz7IPp5vHqAo', + vendor_id='TT9IaNfGN_DmqtsPzKnE4AE', ), ] ) +@pytest.mark.vcr() async def test_gemini_json_schema_output_with_tools(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) @@ -1621,12 +1598,13 @@ class CityLocation(BaseModel): model_name='gemini-2.0-flash', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='dDFIaNO7EbjX1PIPzJ2c-Qo', + vendor_id=IsStr(), ), ] ) +@pytest.mark.vcr() async def test_gemini_json_schema_output_multiple(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) @@ -1679,12 +1657,13 @@ class CountryLanguage(BaseModel): model_name='gemini-2.0-flash', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='ljFIaMSvEYCK7dcP3OzRiQQ', + vendor_id=IsStr(), ), ] ) +@pytest.mark.vcr() async def test_gemini_prompted_json_output(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) @@ -1715,23 +1694,28 @@ class CityLocation(BaseModel): """, ), ModelResponse( - parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], + parts=[ + TextPart( + content='{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + ) + ], usage=Usage( requests=1, request_tokens=80, - response_tokens=13, - total_tokens=93, - details={'text_prompt_tokens': 80, 'text_candidates_tokens': 13}, + response_tokens=56, + total_tokens=136, + details={'text_prompt_tokens': 80, 'text_candidates_tokens': 56}, ), model_name='gemini-2.0-flash', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='vjFIaLnCK9GU7dcPjoS34QI', + vendor_id=IsStr(), ), ] ) +@pytest.mark.vcr() async def test_gemini_prompted_json_output_with_tools(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) @@ -1768,29 +1752,25 @@ async def get_user_country() -> str: """, ), ModelResponse( - parts=[ - ToolCallPart( - tool_name='get_user_country', args={}, tool_call_id='pyd_ai_fc31eb899081445ea0a4369584e16f99' - ) - ], + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], usage=Usage( requests=1, request_tokens=123, response_tokens=12, - total_tokens=2243, - details={'thoughts_tokens': 2108, 'text_prompt_tokens': 123}, + total_tokens=453, + details={'thoughts_tokens': 318, 'text_prompt_tokens': 123}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='1DFIaLOMN_iHz7IP78G_6Ac', + vendor_id=IsStr(), ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='pyd_ai_fc31eb899081445ea0a4369584e16f99', + tool_call_id=IsStr(), timestamp=IsDatetime(), ) ], @@ -1803,31 +1783,24 @@ async def get_user_country() -> str: """, ), ModelResponse( - parts=[ - TextPart( - content="""\ -```json -{"city": "Mexico City", "country": "Mexico"} -```\ -""" - ) - ], + parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], usage=Usage( requests=1, request_tokens=154, - response_tokens=18, - total_tokens=270, - details={'thoughts_tokens': 98, 'text_prompt_tokens': 154}, + response_tokens=13, + total_tokens=261, + details={'thoughts_tokens': 94, 'text_prompt_tokens': 154}, ), model_name='models/gemini-2.5-pro-preview-05-06', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='1jFIaJHOJ-itz7IPvau3kQM', + vendor_id=IsStr(), ), ] ) +@pytest.mark.vcr() async def test_gemini_prompted_json_output_multiple(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) @@ -1877,7 +1850,7 @@ class CountryLanguage(BaseModel): model_name='gemini-2.0-flash', timestamp=IsDatetime(), vendor_details={'finish_reason': 'STOP'}, - vendor_id='6TFIaNuTKoCK7dcP3OzRiQQ', + vendor_id=IsStr(), ), ] ) From 1902d006ffd47a23d71460e7af01229b060e436e Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 14:31:13 +0000 Subject: [PATCH 12/90] Remove some old TODO comments --- pydantic_ai_slim/pydantic_ai/_output.py | 3 +-- pydantic_ai_slim/pydantic_ai/agent.py | 3 --- pydantic_ai_slim/pydantic_ai/profiles/__init__.py | 4 +++- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 22d679017..c14d2c34a 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -127,7 +127,7 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): class ToolOutput(Generic[OutputDataT]): """Marker class to use tools for outputs, and customize the tool.""" - output_type: OutputTypeOrFunction[OutputDataT] # TODO: Allow list of types instead of unions? + output_type: OutputTypeOrFunction[OutputDataT] name: str | None description: str | None max_retries: int | None @@ -219,7 +219,6 @@ def __init__( type_params=(T_co,), ) -# TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose automatically, relying on Pydantic validation OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'prompted_json'] diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index ac62a43a2..c4632e804 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -329,7 +329,6 @@ def __init__( self._instructions_functions = [] if isinstance(instructions, (str, Callable)): instructions = [instructions] - # TODO: Add OutputSchema to the instructions in JSON mode for instruction in instructions or []: if isinstance(instruction, str): self._instructions += instruction + '\n' @@ -1654,8 +1653,6 @@ def _prepare_output_schema( schema = self._output_schema if schema.mode is None: - # TODO: This may need to be done later, when we know if there are any model_request_parameters.function_tools, - # as some models do not support tool calls at the same time as json_schema output, and which mode we pick may be different... schema.mode = model_profile.default_output_mode if not schema.is_mode_supported(model_profile): raise exceptions.UserError( diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 9c311d685..aad62f2a9 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -13,9 +13,11 @@ class ModelProfile: """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used.""" json_schema_transformer: type[JsonSchemaTransformer] | None = None + """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" output_modes: set[Literal['tool', 'json_schema']] = field(default_factory=lambda: {'tool'}) - # TODO: Add docstrings + """The output modes supported by the model. Essentially all models support `tool` mode, but some also support `json_schema` mode, which needs to be specifically implemented on the model class.""" default_output_mode: Literal['tool', 'json_schema', 'prompted_json'] = 'tool' + """The default output mode to use for the model.""" @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: From 1f53c9be9a239ee7ce77a928dc85e77bd4fe2a9b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 14:47:05 +0000 Subject: [PATCH 13/90] Add missing VCR recording of Gemini output mode test --- .../test_gemini_json_schema_output.yaml | 79 +++++++++++++++++++ tests/models/test_gemini.py | 1 + 2 files changed, 80 insertions(+) create mode 100644 tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml diff --git a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml new file mode 100644 index 000000000..d7f14c9ca --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '305' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + response_schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + title: CityLocation + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '710' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=819 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.00018302639946341515 + content: + parts: + - text: |- + { + "city": "Mexico City", + "country": "Mexico" + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: SEVIaJvJHICK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 20 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 20 + promptTokenCount: 17 + promptTokensDetails: + - modality: TEXT + tokenCount: 17 + totalTokenCount: 37 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index d4d765c73..218a4af85 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1555,6 +1555,7 @@ async def get_user_country() -> str: await agent.run('What is the largest city in the user country?') +@pytest.mark.vcr() async def test_gemini_json_schema_output(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) From a4c2877f1f1c8447be30197549cdea66b509b366 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 15:16:41 +0000 Subject: [PATCH 14/90] Add more missing VCR recordings --- .../test_anthropic_prompted_json_output.yaml | 161 ++++++++++++++++++ ...thropic_prompted_json_output_multiple.yaml | 66 +++++++ .../test_anthropic_text_output_function.yaml | 156 +++++++++++++++++ tests/models/test_anthropic.py | 45 +++-- tests/models/test_google.py | 10 +- tests/models/test_openai.py | 29 +++- 6 files changed, 432 insertions(+), 35 deletions(-) create mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml create mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml create mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_text_output_function.yaml diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml new file mode 100644 index 000000000..e88afebdf --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml @@ -0,0 +1,161 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '740' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '397' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - id: toolu_017UryVwtsKsjonhFV3cgV3X + input: {} + name: get_user_country + type: tool_use + id: msg_014CpBKzioMqUyLWrMihpvsz + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 459 + output_tokens: 38 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1002' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + - content: + - id: toolu_017UryVwtsKsjonhFV3cgV3X + input: {} + name: get_user_country + type: tool_use + role: assistant + - content: + - content: Mexico + is_error: false + tool_use_id: toolu_017UryVwtsKsjonhFV3cgV3X + type: tool_result + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '380' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: '{"city": "Mexico City", "country": "Mexico"}' + type: text + id: msg_014JeWCouH6DpdqzMTaBdkpJ + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 510 + output_tokens: 17 + service_tier: standard + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..183daa406 --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml @@ -0,0 +1,66 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1268' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in Mexico? + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '434' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + type: text + id: msg_013ttUi3HCcKt7PkJpoWs5FT + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 281 + output_tokens: 31 + service_tier: standard + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_text_output_function.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_text_output_function.yaml new file mode 100644 index 000000000..ad365d4ac --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_text_output_function.yaml @@ -0,0 +1,156 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '409' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '540' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: I'll help you find the largest city in your country. Let me first check your country using the get_user_country + tool. + type: text + - id: toolu_01EZuxfc6MsPsPgrAKQohw3e + input: {} + name: get_user_country + type: tool_use + id: msg_014NE4yfV1Yz2vLAJzapxxef + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 383 + output_tokens: 66 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '814' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + - content: + - text: I'll help you find the largest city in your country. Let me first check your country using the get_user_country + tool. + type: text + - id: toolu_01EZuxfc6MsPsPgrAKQohw3e + input: {} + name: get_user_country + type: tool_use + role: assistant + - content: + - content: Mexico + is_error: false + tool_use_id: toolu_01EZuxfc6MsPsPgrAKQohw3e + type: tool_result + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '801' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: Based on the result, you are located in Mexico. The largest city in Mexico is Mexico City (Ciudad de México), + which is also the nation's capital. Mexico City has a population of approximately 9.2 million people in the city + proper, and over 21 million people in its metropolitan area, making it one of the largest urban agglomerations in + the world. It is both the political and economic center of Mexico, located in the Valley of Mexico in the central + part of the country. + type: text + id: msg_0193srwo7TCx49h97wDwc7K7 + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 461 + output_tokens: 107 + service_tier: standard + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 1bd395ad1..068c8c116 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1161,6 +1161,7 @@ async def get_user_country() -> str: ) +@pytest.mark.vcr() async def test_anthropic_text_output_function(allow_model_requests: None, anthropic_api_key: str): m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) @@ -1176,13 +1177,9 @@ async def get_user_country() -> str: result = await agent.run( 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' ) - assert result.output == snapshot("""\ -BASED ON THE RESULT, YOU ARE LOCATED IN MEXICO. \n\ - -MEXICO CITY (CIUDAD DE MÉXICO) IS THE LARGEST CITY IN MEXICO, WITH A METROPOLITAN AREA POPULATION OF OVER 22 MILLION PEOPLE (2022 ESTIMATES). IT IS NOT ONLY THE LARGEST CITY IN MEXICO BUT ALSO ONE OF THE LARGEST METROPOLITAN AREAS IN THE WORLD. THE CITY PROPER HAS AROUND 9 MILLION INHABITANTS, WHILE THE GREATER METROPOLITAN AREA INCLUDES MANY SURROUNDING MUNICIPALITIES AND EXTENDS INTO THE STATE OF MEXICO. - -MEXICO CITY SERVES AS THE COUNTRY'S CAPITAL AND ITS MOST IMPORTANT POLITICAL, ECONOMIC, AND CULTURAL CENTER. IT'S LOCATED IN THE VALLEY OF MEXICO IN THE CENTRAL PART OF THE COUNTRY.\ -""") + assert result.output == snapshot( + "BASED ON THE RESULT, YOU ARE LOCATED IN MEXICO. THE LARGEST CITY IN MEXICO IS MEXICO CITY (CIUDAD DE MÉXICO), WHICH IS ALSO THE NATION'S CAPITAL. MEXICO CITY HAS A POPULATION OF APPROXIMATELY 9.2 MILLION PEOPLE IN THE CITY PROPER, AND OVER 21 MILLION PEOPLE IN ITS METROPOLITAN AREA, MAKING IT ONE OF THE LARGEST URBAN AGGLOMERATIONS IN THE WORLD. IT IS BOTH THE POLITICAL AND ECONOMIC CENTER OF MEXICO, LOCATED IN THE VALLEY OF MEXICO IN THE CENTRAL PART OF THE COUNTRY." + ) assert result.all_messages() == snapshot( [ @@ -1199,7 +1196,7 @@ async def get_user_country() -> str: TextPart( content="I'll help you find the largest city in your country. Let me first check your country using the get_user_country tool." ), - ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01NtJsHFTSiiBoKpnzGBsg5C'), + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01EZuxfc6MsPsPgrAKQohw3e'), ], usage=Usage( requests=1, @@ -1215,14 +1212,14 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01YJQ6kNbtvNpwJboZ9peSEq', + vendor_id='msg_014NE4yfV1Yz2vLAJzapxxef', ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='toolu_01NtJsHFTSiiBoKpnzGBsg5C', + tool_call_id='toolu_01EZuxfc6MsPsPgrAKQohw3e', timestamp=IsDatetime(), ) ] @@ -1230,35 +1227,30 @@ async def get_user_country() -> str: ModelResponse( parts=[ TextPart( - content="""\ -Based on the result, you are located in Mexico. \n\ - -Mexico City (Ciudad de México) is the largest city in Mexico, with a metropolitan area population of over 22 million people (2022 estimates). It is not only the largest city in Mexico but also one of the largest metropolitan areas in the world. The city proper has around 9 million inhabitants, while the greater metropolitan area includes many surrounding municipalities and extends into the State of Mexico. - -Mexico City serves as the country's capital and its most important political, economic, and cultural center. It's located in the Valley of Mexico in the central part of the country.\ -""" + content="Based on the result, you are located in Mexico. The largest city in Mexico is Mexico City (Ciudad de México), which is also the nation's capital. Mexico City has a population of approximately 9.2 million people in the city proper, and over 21 million people in its metropolitan area, making it one of the largest urban agglomerations in the world. It is both the political and economic center of Mexico, located in the Valley of Mexico in the central part of the country." ) ], usage=Usage( requests=1, request_tokens=461, - response_tokens=135, - total_tokens=596, + response_tokens=107, + total_tokens=568, details={ 'cache_creation_input_tokens': 0, 'cache_read_input_tokens': 0, 'input_tokens': 461, - 'output_tokens': 135, + 'output_tokens': 107, }, ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_013YNHomHSuhmzWZq8iWjvjK', + vendor_id='msg_0193srwo7TCx49h97wDwc7K7', ), ] ) +@pytest.mark.vcr() async def test_anthropic_prompted_json_output(allow_model_requests: None, anthropic_api_key: str): m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) @@ -1296,7 +1288,7 @@ async def get_user_country() -> str: ), ModelResponse( parts=[ - ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01FdQREaVXQbaH7JrFQaTzKb') + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_017UryVwtsKsjonhFV3cgV3X') ], usage=Usage( requests=1, @@ -1312,14 +1304,14 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01FjTTxFKgUXP2cp5s3d7fYh', + vendor_id='msg_014CpBKzioMqUyLWrMihpvsz', ), ModelRequest( parts=[ ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='toolu_01FdQREaVXQbaH7JrFQaTzKb', + tool_call_id='toolu_017UryVwtsKsjonhFV3cgV3X', timestamp=IsDatetime(), ) ], @@ -1347,12 +1339,13 @@ async def get_user_country() -> str: ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01HrNfqrq9UGB54S5xhbCWY5', + vendor_id='msg_014JeWCouH6DpdqzMTaBdkpJ', ), ] ) +@pytest.mark.vcr() async def test_anthropic_prompted_json_output_multiple(allow_model_requests: None, anthropic_api_key: str): m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) @@ -1406,7 +1399,7 @@ class CountryLanguage(BaseModel): ), model_name='claude-3-5-sonnet-20241022', timestamp=IsDatetime(), - vendor_id='msg_01HyWKS3uRkhUw5mWKJY2iZN', + vendor_id='msg_013ttUi3HCcKt7PkJpoWs5FT', ), ] ) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index a62abf188..f9458f975 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -954,11 +954,7 @@ async def get_user_country() -> str: """, ), ModelResponse( - parts=[ - ToolCallPart( - tool_name='get_user_country', args={}, tool_call_id='pyd_ai_479a74a75212414fb3c7bd2242e9b669' - ) - ], + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], usage=Usage( requests=1, request_tokens=123, @@ -975,7 +971,7 @@ async def get_user_country() -> str: ToolReturnPart( tool_name='get_user_country', content='Mexico', - tool_call_id='pyd_ai_479a74a75212414fb3c7bd2242e9b669', + tool_call_id=IsStr(), timestamp=IsDatetime(), ) ], @@ -1040,7 +1036,7 @@ class CountryLanguage(BaseModel): instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 8d2330593..03e9d46c5 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -538,6 +538,31 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): assert result.is_complete +async def test_stream_structured_json_schema_output(allow_model_requests: None): + stream = [ + chunk([]), + text_chunk('{"first": "One'), + text_chunk('", "second": "Two"'), + text_chunk('}'), + chunk([]), + ] + mock_client = MockOpenAI.create_mock_stream(stream) + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(m, output_type=JsonSchemaOutput(MyTypedDict)) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + + async def test_no_content(allow_model_requests: None): stream = [chunk([ChoiceDelta()]), chunk([ChoiceDelta()])] mock_client = MockOpenAI.create_mock_stream(stream) @@ -2224,7 +2249,7 @@ async def get_user_country() -> str: instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, @@ -2262,7 +2287,7 @@ async def get_user_country() -> str: instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, From 56e58f9016da635d63b915f9baefdbb2b397bfc0 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 15:31:30 +0000 Subject: [PATCH 15/90] Fix OpenAI tools --- pydantic_ai_slim/pydantic_ai/models/openai.py | 2 -- tests/models/test_gemini.py | 2 +- tests/models/test_google.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index f634ca686..cf08839b6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -273,9 +273,7 @@ async def _completions_create( openai_messages = await self._map_messages(messages) - tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] response_format: chat.completion_create_params.ResponseFormat | None = None - output_mode = model_request_parameters.output_mode if output_mode == 'json_schema': output_object = model_request_parameters.output_object diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 218a4af85..b65f4790f 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1549,7 +1549,7 @@ class CityLocation(BaseModel): @agent.tool_plain async def get_user_country() -> str: - return 'Mexico' + return 'Mexico' # pragma: no cover with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): await agent.run('What is the largest city in the user country?') diff --git a/tests/models/test_google.py b/tests/models/test_google.py index f9458f975..d606a9a80 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -757,7 +757,7 @@ class CityLocation(BaseModel): @agent.tool_plain async def get_user_country() -> str: - return 'Mexico' + return 'Mexico' # pragma: no cover with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): await agent.run('What is the largest city in the user country?') From a5234e10311e704b5d3ae5a265a6d25298ea8505 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 15:55:55 +0000 Subject: [PATCH 16/90] Improve test coverage --- pydantic_ai_slim/pydantic_ai/_output.py | 5 +++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 12 ++++++------ tests/models/test_gemini.py | 2 ++ tests/models/test_google.py | 2 ++ tests/models/test_openai.py | 2 ++ tests/models/test_openai_responses.py | 2 ++ tests/test_agent.py | 12 +++++++++++- 7 files changed, 28 insertions(+), 9 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index c14d2c34a..ee739b41d 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -296,7 +296,8 @@ def __init__( if isinstance(text_output, TextOutput): self.text_output_schema = OutputTextSchema(text_output.output_type) - elif text_output is str: + else: + assert text_output is str self.text_output_schema = cast(OutputTextSchema[OutputDataT], OutputTextSchema(text_output)) elif len(tool_outputs) > 0: self.mode = 'tool' @@ -556,7 +557,7 @@ async def process( try: object_schema = self._object_schemas[kind] except KeyError as e: - raise ToolRetryError(_messages.RetryPromptPart(content=f'Invalid kind: {kind}')) from e + raise ToolRetryError(_messages.RetryPromptPart(content=f'Invalid kind: {kind}')) from e # pragma: no cover return await object_schema.process( data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index cf08839b6..41404d7ed 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -694,12 +694,12 @@ async def _responses_create( ): text = {'format': {'type': 'json_object'}} - if isinstance(instructions, str): - # Without this trick, we'd hit this error: - # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. - # Apparently they're only checking input messages for "JSON", not instructions. - openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) - instructions = NOT_GIVEN + # Without this trick, we'd hit this error: + # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. + # Apparently they're only checking input messages for "JSON", not instructions. + assert isinstance(instructions, str) + openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) + instructions = NOT_GIVEN try: extra_headers = model_settings.get('extra_headers', {}) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index b65f4790f..a783c36d2 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1560,6 +1560,8 @@ async def test_gemini_json_schema_output(allow_model_requests: None, gemini_api_ m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): + """A city and its country.""" + city: str country: str diff --git a/tests/models/test_google.py b/tests/models/test_google.py index d606a9a80..a2429e5b4 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -767,6 +767,8 @@ async def test_google_json_schema_output(allow_model_requests: None, google_prov m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): + """A city and its country.""" + city: str country: str diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 03e9d46c5..169e4b980 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -1956,6 +1956,8 @@ async def test_openai_json_schema_output(allow_model_requests: None, openai_api_ m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): + """A city and its country.""" + city: str country: str diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index 5ba449a9c..e01f9bcb9 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -662,6 +662,8 @@ async def test_json_schema_output(allow_model_requests: None, openai_api_key: st m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): + """A city and its country.""" + city: str country: str diff --git a/tests/test_agent.py b/tests/test_agent.py index 12b4fb700..11e9797eb 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -13,7 +13,7 @@ from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import TextOutput, ToolOutput +from pydantic_ai._output import JsonSchemaOutput, TextOutput, ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -2724,3 +2724,13 @@ def foo_tool(foo: Foo) -> int: 'kind': 'request', } ) + + +def test_unsupported_output_mode(): + class Foo(BaseModel): + bar: str + + agent = Agent('test', output_type=JsonSchemaOutput(Foo)) + + with pytest.raises(UserError, match="Output mode 'json_schema' is not among supported modes: 'tool'"): + agent.run_sync('Hello') From 40def086dc2c6496eb2c433b9c96c0b6c9f51ce2 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 15:58:44 +0000 Subject: [PATCH 17/90] Update unsupported output mode error message --- pydantic_ai_slim/pydantic_ai/agent.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index c4632e804..62544ea20 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1655,9 +1655,8 @@ def _prepare_output_schema( if schema.mode is None: schema.mode = model_profile.default_output_mode if not schema.is_mode_supported(model_profile): - raise exceptions.UserError( - f"Output mode '{schema.mode}' is not among supported modes: {model_profile.output_modes}" - ) + modes = ', '.join(f"'{m}'" for m in model_profile.output_modes) + raise exceptions.UserError(f"Output mode '{schema.mode}' is not among supported modes: {modes}") return schema # pyright: ignore[reportReturnType] From 837d305db43f26c0824a76a8f563eb6d528acfb3 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 23:03:10 +0000 Subject: [PATCH 18/90] Improve test coverage --- pydantic_ai_slim/pydantic_ai/_output.py | 18 +- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- .../pydantic_ai/models/function.py | 24 ++- pydantic_ai_slim/pydantic_ai/models/openai.py | 4 +- tests/test_agent.py | 198 +++++++++++++++++- tests/test_streaming.py | 20 ++ 6 files changed, 250 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ee739b41d..6d84bf1db 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -287,7 +287,7 @@ def __init__( if len(text_outputs) > 0: if len(text_outputs) > 1: - raise UserError('Only one text output is allowed') + raise UserError('Only one text output is allowed.') text_output = text_outputs[0] self.mode = 'text' @@ -301,7 +301,7 @@ def __init__( self.text_output_schema = cast(OutputTextSchema[OutputDataT], OutputTextSchema(text_output)) elif len(tool_outputs) > 0: self.mode = 'tool' - else: + elif len(other_outputs) > 0: self.text_output_schema = self._build_text_output_schema( other_outputs, name=name, description=description, strict=strict ) @@ -361,7 +361,7 @@ def _build_text_output_schema( strict: bool | None = None, ) -> OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | None: if len(outputs) == 0: - return None + return None # pragma: no cover outputs = flatten_output_types(outputs) if len(outputs) == 1: @@ -466,9 +466,9 @@ class OutputObjectDefinition: def instructions(self) -> str: """Get instructions for model to output manual JSON matching the schema.""" schema = self.json_schema.copy() - if self.name and not schema.get('title'): + if self.name: schema['title'] = self.name - if self.description and not schema.get('description'): + if self.description: schema['description'] = self.description # Eventually move DEFAULT_PROMPTED_JSON_PROMPT to ModelProfile so it can be tweaked on a per model basis @@ -650,7 +650,7 @@ async def process( ) raise ToolRetryError(m) from e else: - raise + raise # pragma: lax no cover if k := self.outer_typed_dict_key: output = output[k] @@ -665,7 +665,7 @@ async def process( ) raise ToolRetryError(m) from r else: - raise + raise # pragma: lax no cover return output @@ -691,7 +691,7 @@ def __init__( elif output_type is str: return - raise ValueError('OutputTextSchema must take the `str` type or a function taking a `str`') + raise UserError('TextOutput must take the `str` type or a function taking a `str`') @property def object_def(self) -> None: @@ -716,7 +716,7 @@ async def process( ) raise ToolRetryError(m) from r else: - raise + raise # pragma: lax no cover return cast(OutputDataT, output) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 62544ea20..ca18baa46 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1007,7 +1007,7 @@ async def stream_to_final( if isinstance(new_part, _messages.TextPart): if output_schema.allow_text_output: return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart): + elif isinstance(new_part, _messages.ToolCallPart): # pragma: no branch for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) return None diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 22bcddffb..ce97e4196 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -11,6 +11,8 @@ from typing_extensions import TypeAlias, assert_never, overload +from pydantic_ai.profiles import ModelProfileSpec + from .. import _utils, usage from .._utils import PeekableAsyncStream from ..messages import ( @@ -48,14 +50,27 @@ class FunctionModel(Model): _system: str = field(default='function', repr=False) @overload - def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ... + def __init__( + self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None + ) -> None: ... @overload - def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ... + def __init__( + self, + *, + stream_function: StreamFunctionDef, + model_name: str | None = None, + profile: ModelProfileSpec | None = None, + ) -> None: ... @overload def __init__( - self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None + self, + function: FunctionDef, + *, + stream_function: StreamFunctionDef, + model_name: str | None = None, + profile: ModelProfileSpec | None = None, ) -> None: ... def __init__( @@ -64,6 +79,7 @@ def __init__( *, stream_function: StreamFunctionDef | None = None, model_name: str | None = None, + profile: ModelProfileSpec | None = None, ): """Initialize a `FunctionModel`. @@ -73,6 +89,7 @@ def __init__( function: The function to call for non-streamed requests. stream_function: The function to call for streamed requests. model_name: The name of the model. If not provided, a name is generated from the function names. + profile: The model profile to use. """ if function is None and stream_function is None: raise TypeError('Either `function` or `stream_function` must be provided') @@ -82,6 +99,7 @@ def __init__( function_name = self.function.__name__ if self.function is not None else '' stream_function_name = self.stream_function.__name__ if self.stream_function is not None else '' self._model_name = model_name or f'function:{function_name}:{stream_function_name}' + self._profile = profile async def request( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 41404d7ed..f46172c92 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -429,7 +429,7 @@ def _map_json_schema(o: OutputObjectDefinition) -> chat.completion_create_params } if o.description: response_format_param['json_schema']['description'] = o.description - if o.strict: + if o.strict: # pragma: no branch response_format_param['json_schema']['strict'] = o.strict return response_format_param @@ -820,7 +820,7 @@ def _map_json_schema(o: OutputObjectDefinition) -> responses.ResponseFormatTextJ } if o.description: response_format_param['description'] = o.description - if o.strict: + if o.strict: # pragma: no branch response_format_param['strict'] = o.strict return response_format_param diff --git a/tests/test_agent.py b/tests/test_agent.py index 11e9797eb..975500fe1 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -13,7 +13,7 @@ from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import JsonSchemaOutput, TextOutput, ToolOutput +from pydantic_ai._output import JsonSchemaOutput, OutputType, PromptedJsonOutput, TextOutput, ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -31,6 +31,7 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition @@ -925,6 +926,24 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +@pytest.mark.parametrize( + 'output_type', + [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(str)]], +) +def test_output_type_multiple_text_output(output_type: OutputType[str]): + with pytest.raises(UserError, match='Only one text output is allowed.'): + Agent('test', output_type=output_type) + + +def test_output_type_text_output_invalid(): + def int_func(x: int) -> str: + return str(int) + + with pytest.raises(UserError, match='TextOutput must take the `str` type or a function taking a `str`'): + output_type: TextOutput[str] = TextOutput(int_func) # type: ignore + Agent('test', output_type=output_type) + + def test_output_type_async_function(): class Weather(BaseModel): temperature: float @@ -1235,6 +1254,183 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_prompted_json(): + def return_city_location(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + text = CityLocation(city='Mexico City', country='Mexico').model_dump_json() + return ModelResponse(parts=[TextPart(content=text)]) + + m = FunctionModel(return_city_location) + + class CityLocation(BaseModel): + """Description from docstring.""" + + city: str + country: str + + agent = Agent( + m, + output_type=PromptedJsonOutput( + CityLocation, name='City & Country', description='Description from PromptedJsonOutput' + ), + ) + + result = agent.run_sync('What is the capital of Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from PromptedJsonOutput. Description from docstring."} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63), + model_name='function:return_city_location:', + timestamp=IsDatetime(), + ), + ] + ) + + +def test_output_type_json_schema(): + def return_city_location(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + text = '{"city": "Mexico City"}' + else: + text = '{"city": "Mexico City", "country": "Mexico"}' + return ModelResponse(parts=[TextPart(content=text)]) + + m = FunctionModel(return_city_location, profile=ModelProfile(output_modes={'tool', 'json_schema'})) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent( + m, + output_type=JsonSchemaOutput(CityLocation), + ) + + result = agent.run_sync('What is the capital of Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City"}')], + usage=Usage(requests=1, request_tokens=56, response_tokens=5, total_tokens=61), + model_name='function:return_city_location:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content=[ + { + 'type': 'missing', + 'loc': ('country',), + 'msg': 'Field required', + 'input': {'city': 'Mexico City'}, + } + ], + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], + usage=Usage(requests=1, request_tokens=85, response_tokens=12, total_tokens=97), + model_name='function:return_city_location:', + timestamp=IsDatetime(), + ), + ] + ) + + +def test_output_type_prompted_json_function_with_retry(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + if city != 'Mexico City': + raise ModelRetry('City not found, I only know Mexico City') + return Weather(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + args_json = '{"city": "New York City"}' + else: + args_json = '{"city": "Mexico City"}' + + return ModelResponse(parts=[TextPart(content=args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=PromptedJsonOutput(get_weather)) + result = agent.run_sync('New York City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='New York City', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"additionalProperties": false, "properties": {"city": {"type": "string"}}, "required": ["city"], "type": "object", "title": "get_weather"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city": "New York City"}')], + usage=Usage(requests=1, request_tokens=53, response_tokens=6, total_tokens=59), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='City not found, I only know Mexico City', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City"}')], + usage=Usage(requests=1, request_tokens=68, response_tokens=11, total_tokens=79), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ] + ) + + def test_run_with_history_new(): m = TestModel() diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 12d29a9b0..125b3b6a6 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -13,6 +13,7 @@ from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai._output import TextOutput from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( ModelMessage, @@ -192,6 +193,22 @@ async def test_streamed_text_stream(): ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.'] ) + def upcase(text: str) -> str: + return text.upper() + + async with agent.run_stream('Hello', output_type=TextOutput(upcase)) as result: + assert [c async for c in result.stream(debounce_by=None)] == snapshot( + [ + 'THE ', + 'THE CAT ', + 'THE CAT SAT ', + 'THE CAT SAT ON ', + 'THE CAT SAT ON THE ', + 'THE CAT SAT ON THE MAT.', + 'THE CAT SAT ON THE MAT.', + ] + ) + async with agent.run_stream('Hello') as result: assert [c async for c, _is_last in result.stream_structured(debounce_by=None)] == snapshot( [ @@ -921,3 +938,6 @@ def output_validator(data: OutputType | NotOutputType) -> OutputType | NotOutput async for output in stream.stream_output(debounce_by=None): outputs.append(output) assert outputs == [OutputType(value='a (validated)'), OutputType(value='a (validated)')] + + +# TODO: Test streaming structured output coming as text not tool calls From 5f71ba886d29ef6196a5f863b921a6d0f9e28906 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 23:42:53 +0000 Subject: [PATCH 19/90] Test streaming with structured text output --- pydantic_ai_slim/pydantic_ai/_output.py | 70 +++++++++++---------- pydantic_ai_slim/pydantic_ai/models/test.py | 2 +- pydantic_ai_slim/pydantic_ai/result.py | 7 ++- tests/test_agent.py | 9 +-- tests/test_streaming.py | 24 ++++++- 5 files changed, 70 insertions(+), 42 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 6d84bf1db..a8c597435 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -151,12 +151,9 @@ def __init__( @dataclass class TextOutput(Generic[OutputDataT]): - """Marker class to use text output for outputs.""" + """Marker class to use text output with an output function.""" - output_type: ( - Callable[[RunContext, str], Awaitable[OutputDataT] | OutputDataT] - | Callable[[str], Awaitable[OutputDataT] | OutputDataT] - ) + output_function: TextOutputFunction[OutputDataT] @dataclass(init=False) @@ -183,7 +180,7 @@ def __init__( class PromptedJsonOutput(Generic[OutputDataT]): - """Marker class to use manual JSON mode for outputs.""" + """Marker class to use prompted JSON mode for outputs.""" output_types: Sequence[OutputTypeOrFunction[OutputDataT]] name: str | None @@ -219,6 +216,15 @@ def __init__( type_params=(T_co,), ) +TextOutputFunction = TypeAliasType( + 'TextOutputFunction', + Union[ + Callable[[RunContext, str], Awaitable[T_co] | T_co], + Callable[[str], Awaitable[T_co] | T_co], + ], + type_params=(T_co,), +) + OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'prompted_json'] @@ -250,7 +256,6 @@ def __init__( if output_type is str: self.mode = 'text' - self.text_output_schema = OutputTextSchema(output_type) return if isinstance(output_type, JsonSchemaOutput): @@ -295,16 +300,15 @@ def __init__( self.mode = 'tool_or_text' if isinstance(text_output, TextOutput): - self.text_output_schema = OutputTextSchema(text_output.output_type) - else: - assert text_output is str - self.text_output_schema = cast(OutputTextSchema[OutputDataT], OutputTextSchema(text_output)) + self.text_output_schema = OutputTextSchema(text_output.output_function) elif len(tool_outputs) > 0: self.mode = 'tool' elif len(other_outputs) > 0: self.text_output_schema = self._build_text_output_schema( other_outputs, name=name, description=description, strict=strict ) + else: + raise UserError('No output type provided.') # pragma: no cover @staticmethod def _build_tools( @@ -435,7 +439,9 @@ async def process( Either the validated output data (left) or a retry message (right). """ assert self.allow_text_output is not False - assert self.text_output_schema is not None + + if self.text_output_schema is None: + return cast(OutputDataT, text) def strip_markdown_fences(text: str) -> str: if text.startswith('{'): @@ -672,26 +678,23 @@ async def process( @dataclass(init=False) class OutputTextSchema(Generic[OutputDataT]): - _function_schema: _function_schema.FunctionSchema | None = None - _str_argument_name: str | None = None + _function_schema: _function_schema.FunctionSchema + _str_argument_name: str def __init__( self, - output_type: type[OutputDataT] - | Callable[[RunContext[AgentDepsT], str], Awaitable[OutputDataT] | OutputDataT] - | Callable[[str], Awaitable[OutputDataT] | OutputDataT] = str, + output_function: TextOutputFunction[OutputDataT], ): - if inspect.isfunction(output_type) or inspect.ismethod(output_type): - self._function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) + if inspect.isfunction(output_function) or inspect.ismethod(output_function): + self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema) + arguments_schema = self._function_schema.json_schema.get('properties', {}) argument_name = next(iter(arguments_schema.keys()), None) if argument_name and arguments_schema.get(argument_name, {}).get('type') == 'string': self._str_argument_name = argument_name return - elif output_type is str: - return - raise UserError('TextOutput must take the `str` type or a function taking a `str`') + raise UserError('TextOutput must take a function taking a `str`') @property def object_def(self) -> None: @@ -704,19 +707,18 @@ async def process( allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: - output = data + args = {self._str_argument_name: data} - if self._function_schema and self._str_argument_name: - try: - output = await self._function_schema.call({self._str_argument_name: output}, run_context) - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - content=r.message, - ) - raise ToolRetryError(m) from r - else: - raise # pragma: lax no cover + try: + output = await self._function_schema.call(args, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: lax no cover return cast(OutputDataT, output) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 0daad25bc..8b8453ac6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -130,7 +130,7 @@ def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> l def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput: if self.custom_output_text is not None: - assert model_request_parameters.allow_text_output, ( + assert model_request_parameters.output_mode != 'tool', ( 'Plain response not allowed, but `custom_output_text` is set.' ) assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.' diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index ef2c82aa4..57174e27e 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -7,6 +7,7 @@ from datetime import datetime from typing import Generic +from pydantic import ValidationError from typing_extensions import TypeVar, assert_type, deprecated, overload from . import _utils, exceptions, messages as _messages, models @@ -306,7 +307,11 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Outp An async iterable of the response data. """ async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): - yield await self.validate_structured_output(structured_message, allow_partial=not is_last) + try: + yield await self.validate_structured_output(structured_message, allow_partial=not is_last) + except ValidationError: + if is_last: + raise # pragma: lax no cover async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. diff --git a/tests/test_agent.py b/tests/test_agent.py index 975500fe1..dcad1e9b9 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -873,7 +873,8 @@ class Weather(BaseModel): temperature: float description: str - def get_weather(city: str) -> Weather: + def get_weather(ctx: RunContext[None], city: str) -> Weather: + assert ctx is not None if city != 'Mexico City': raise ModelRetry('City not found, I only know Mexico City') return Weather(temperature=28.7, description='sunny') @@ -928,7 +929,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: @pytest.mark.parametrize( 'output_type', - [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(str)]], + [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], ) def test_output_type_multiple_text_output(output_type: OutputType[str]): with pytest.raises(UserError, match='Only one text output is allowed.'): @@ -937,9 +938,9 @@ def test_output_type_multiple_text_output(output_type: OutputType[str]): def test_output_type_text_output_invalid(): def int_func(x: int) -> str: - return str(int) + return str(int) # pragma: no cover - with pytest.raises(UserError, match='TextOutput must take the `str` type or a function taking a `str`'): + with pytest.raises(UserError, match='TextOutput must take a function taking a `str`'): output_type: TextOutput[str] = TextOutput(int_func) # type: ignore Agent('test', output_type=output_type) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 125b3b6a6..75935d33f 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import TextOutput +from pydantic_ai._output import PromptedJsonOutput, TextOutput from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( ModelMessage, @@ -940,4 +940,24 @@ def output_validator(data: OutputType | NotOutputType) -> OutputType | NotOutput assert outputs == [OutputType(value='a (validated)'), OutputType(value='a (validated)')] -# TODO: Test streaming structured output coming as text not tool calls +async def test_stream_output_type_prompted_json(): + class CityLocation(BaseModel): + city: str + country: str | None = None + + m = TestModel(custom_output_text='{"city": "Mexico City", "country": "Mexico"}') + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream(debounce_by=None)] == snapshot( + [ + CityLocation(city='Mexico '), + CityLocation(city='Mexico City'), + CityLocation(city='Mexico City'), + CityLocation(city='Mexico City', country='Mexico'), + CityLocation(city='Mexico City', country='Mexico'), + ] + ) + assert result.is_complete From cfc274923ffa66a957ff79652147be1d830314dd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 10 Jun 2025 23:53:41 +0000 Subject: [PATCH 20/90] Make TextOutputFunction Python 3.9 compatible --- pydantic_ai_slim/pydantic_ai/_output.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index a8c597435..dc1c2af40 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -219,8 +219,8 @@ def __init__( TextOutputFunction = TypeAliasType( 'TextOutputFunction', Union[ - Callable[[RunContext, str], Awaitable[T_co] | T_co], - Callable[[str], Awaitable[T_co] | T_co], + Callable[[RunContext, str], Union[Awaitable[T_co], T_co]], + Callable[[str], Union[Awaitable[T_co], T_co]], ], type_params=(T_co,), ) From a1376413a433164e0611481b1d5804c6ac28e4e3 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 11 Jun 2025 01:37:45 +0000 Subject: [PATCH 21/90] Properly merge JSON schemas accounting for defs --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/_output.py | 189 ++++++++++------- pydantic_ai_slim/pydantic_ai/_utils.py | 81 +++++++- pydantic_ai_slim/pydantic_ai/agent.py | 22 +- tests/models/test_anthropic.py | 2 +- tests/models/test_gemini.py | 2 +- tests/models/test_google.py | 2 +- tests/models/test_openai.py | 4 +- tests/models/test_openai_responses.py | 4 +- tests/test_agent.py | 78 ++++++- tests/test_utils.py | 203 ++++++++++++++++++- 11 files changed, 488 insertions(+), 101 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 9d9a20cb6..657de060b 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -841,7 +841,7 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], - output_type: _output.OutputType[OutputT], + output_type: _output.OutputSpec[OutputT], ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index dc1c2af40..33fcca216 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -160,7 +160,7 @@ class TextOutput(Generic[OutputDataT]): class JsonSchemaOutput(Generic[OutputDataT]): """Marker class to use JSON schema output for outputs.""" - output_types: Sequence[OutputTypeOrFunction[OutputDataT]] + outputs: Sequence[OutputTypeOrFunction[OutputDataT]] name: str | None description: str | None strict: bool | None @@ -173,7 +173,7 @@ def __init__( description: str | None = None, strict: bool | None = True, ): - self.output_types = flatten_output_types(type_) + self.outputs = flatten_output_spec(type_) self.name = name self.description = description self.strict = strict @@ -182,7 +182,7 @@ def __init__( class PromptedJsonOutput(Generic[OutputDataT]): """Marker class to use prompted JSON mode for outputs.""" - output_types: Sequence[OutputTypeOrFunction[OutputDataT]] + outputs: Sequence[OutputTypeOrFunction[OutputDataT]] name: str | None description: str | None @@ -193,7 +193,7 @@ def __init__( name: str | None = None, description: str | None = None, ): - self.output_types = flatten_output_types(type_) + self.outputs = flatten_output_spec(type_) self.name = name self.description = description @@ -203,8 +203,8 @@ def __init__( OutputTypeOrFunction = TypeAliasType( 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) ) -OutputType = TypeAliasType( - 'OutputType', +OutputSpec = TypeAliasType( + 'OutputSpec', Union[ OutputTypeOrFunction[T_co], ToolOutput[T_co], @@ -230,20 +230,17 @@ def __init__( @dataclass(init=False) class OutputSchema(Generic[OutputDataT]): - """Model the final output from an agent run. - - Similar to `Tool` but for the final output of running an agent. - """ + """Model the final output from an agent run.""" mode: OutputMode | None = None text_output_schema: ( - OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | OutputTextSchema[OutputDataT] | None + OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | OutputFunctionSchema[OutputDataT] | None ) = None tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) def __init__( self, - output_type: OutputType[OutputDataT], + output_spec: OutputSpec[OutputDataT], *, name: str | None = None, description: str | None = None, @@ -254,39 +251,39 @@ def __init__( self.text_output_schema = None self.tools = {} - if output_type is str: + if output_spec is str: self.mode = 'text' return - if isinstance(output_type, JsonSchemaOutput): + if isinstance(output_spec, JsonSchemaOutput): self.mode = 'json_schema' self.text_output_schema = self._build_text_output_schema( - output_type.output_types, - name=output_type.name, - description=output_type.description, - strict=output_type.strict, + output_spec.outputs, + name=output_spec.name, + description=output_spec.description, + strict=output_spec.strict, ) return - if isinstance(output_type, PromptedJsonOutput): + if isinstance(output_spec, PromptedJsonOutput): self.mode = 'prompted_json' self.text_output_schema = self._build_text_output_schema( - output_type.output_types, name=output_type.name, description=output_type.description + output_spec.outputs, name=output_spec.name, description=output_spec.description ) return text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] - for output_type_or_marker in flatten_output_types(output_type): - if output_type_or_marker is str: - text_outputs.append(cast(type[str], output_type_or_marker)) - elif isinstance(output_type_or_marker, TextOutput): - text_outputs.append(output_type_or_marker) - elif isinstance(output_type_or_marker, ToolOutput): - tool_outputs.append(output_type_or_marker) + for output in flatten_output_spec(output_spec): + if output is str: + text_outputs.append(cast(type[str], output)) + elif isinstance(output, TextOutput): + text_outputs.append(output) + elif isinstance(output, ToolOutput): + tool_outputs.append(output) else: - other_outputs.append(output_type_or_marker) + other_outputs.append(output) self.tools = self._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) @@ -300,7 +297,7 @@ def __init__( self.mode = 'tool_or_text' if isinstance(text_output, TextOutput): - self.text_output_schema = OutputTextSchema(text_output.output_function) + self.text_output_schema = OutputFunctionSchema(text_output.output_function) elif len(tool_outputs) > 0: self.mode = 'tool' elif len(other_outputs) > 0: @@ -352,7 +349,7 @@ def _build_tools( if strict is None: strict = default_strict - parameters_schema = OutputObjectSchema(output_type=output_type, description=description, strict=strict) + parameters_schema = OutputObjectSchema(output=output_type, description=description, strict=strict) tools[name] = OutputTool(name=name, parameters_schema=parameters_schema, multiple=multiple) return tools @@ -367,11 +364,11 @@ def _build_text_output_schema( if len(outputs) == 0: return None # pragma: no cover - outputs = flatten_output_types(outputs) + outputs = flatten_output_spec(outputs) if len(outputs) == 1: - return OutputObjectSchema(output_type=outputs[0], name=name, description=description, strict=strict) + return OutputObjectSchema(output=outputs[0], name=name, description=description, strict=strict) - return OutputUnionSchema(output_types=outputs, strict=strict) + return OutputUnionSchema(outputs=outputs, strict=strict, name=name, description=description) @property def allow_text_output(self) -> Literal['plain', 'json', False]: @@ -501,48 +498,81 @@ class OutputUnionSchema(Generic[OutputDataT]): def __init__( self, - output_types: Sequence[OutputTypeOrFunction[OutputDataT]], + outputs: Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, strict: bool | None = None, ): + json_schemas: list[ObjectJsonSchema] = [] self._object_schemas = {} - # TODO: Ensure keys are unique - self._object_schemas = { - output_type.__name__: OutputObjectSchema(output_type=output_type, strict=strict) - for output_type in output_types - } + for output in outputs: + object_schema = OutputObjectSchema(output=output, strict=strict) + object_def = object_schema.object_def - self._root_object_schema = OutputObjectSchema(output_type=OutputUnionData) + object_key = object_def.name or output.__name__ + i = 1 + original_key = object_key + while object_key in self._object_schemas: + i += 1 + object_key = f'{original_key}_{i}' + + self._object_schemas[object_key] = object_schema + + json_schema = object_def.json_schema + if object_name := object_def.name: + json_schema['title'] = object_name + if object_description := object_def.description: + json_schema['description'] = object_description + + json_schemas.append(json_schema) + + json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas) + + discriminated_json_schemas: list[ObjectJsonSchema] = [] + for object_key, json_schema in zip(self._object_schemas.keys(), json_schemas): + title = json_schema.pop('title', None) + description = json_schema.pop('description', None) + + discriminated_json_schema = { + 'type': 'object', + 'properties': { + 'kind': { + 'type': 'string', + 'const': object_key, + }, + 'data': json_schema, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + } + if title: + discriminated_json_schema['title'] = title + if description: + discriminated_json_schema['description'] = description + + discriminated_json_schemas.append(discriminated_json_schema) + + self._root_object_schema = OutputObjectSchema(output=OutputUnionData) - # TODO: Account for conflicting $defs and $refs json_schema = { 'type': 'object', 'properties': { 'result': { - 'anyOf': [ - { - 'type': 'object', - 'properties': { - 'kind': { - 'type': 'string', - 'const': name, - }, - 'data': object_schema.object_def.json_schema, # TODO: Pop description here? - }, - 'description': object_schema.object_def.description or name, # TODO: Better description - 'required': ['kind', 'data'], - 'additionalProperties': False, - } - for name, object_schema in self._object_schemas.items() - ], + 'anyOf': discriminated_json_schemas, } }, 'required': ['result'], 'additionalProperties': False, } + if all_defs: + json_schema['$defs'] = all_defs self.object_def = OutputObjectDefinition( json_schema=json_schema, strict=strict, + name=name, + description=description, ) async def process( @@ -552,7 +582,6 @@ async def process( allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: - # TODO: Error handling? result = await self._root_object_schema.process( data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -563,7 +592,11 @@ async def process( try: object_schema = self._object_schemas[kind] except KeyError as e: - raise ToolRetryError(_messages.RetryPromptPart(content=f'Invalid kind: {kind}')) from e # pragma: no cover + if wrap_validation_errors: + m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}') + raise ToolRetryError(m) from e + else: + raise # pragma: lax no cover return await object_schema.process( data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors @@ -579,26 +612,26 @@ class OutputObjectSchema(Generic[OutputDataT]): def __init__( self, - output_type: OutputTypeOrFunction[OutputDataT], + output: OutputTypeOrFunction[OutputDataT], *, name: str | None = None, description: str | None = None, strict: bool | None = None, ): - if inspect.isfunction(output_type) or inspect.ismethod(output_type): - self._function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) + if inspect.isfunction(output) or inspect.ismethod(output): + self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema) self._validator = self._function_schema.validator json_schema = self._function_schema.json_schema json_schema['description'] = self._function_schema.description else: type_adapter: TypeAdapter[Any] - if _utils.is_model_like(output_type): - type_adapter = TypeAdapter(output_type) + if _utils.is_model_like(output): + type_adapter = TypeAdapter(output) else: self.outer_typed_dict_key = 'response' response_data_typed_dict = TypedDict( # noqa: UP013 'response_data_typed_dict', - {'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm] + {'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm] ) type_adapter = TypeAdapter(response_data_typed_dict) @@ -619,7 +652,7 @@ def __init__( description = f'{description}. {json_schema_description}' self.object_def = OutputObjectDefinition( - name=name or getattr(output_type, '__name__', None), + name=name or getattr(output, '__name__', None), description=description, json_schema=json_schema, strict=strict, @@ -677,7 +710,7 @@ async def process( @dataclass(init=False) -class OutputTextSchema(Generic[OutputDataT]): +class OutputFunctionSchema(Generic[OutputDataT]): _function_schema: _function_schema.FunctionSchema _str_argument_name: str @@ -804,17 +837,17 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return () -def flatten_output_types(output_type: T | Sequence[T]) -> list[T]: - output_types: Sequence[T] - if isinstance(output_type, Sequence): - output_types = output_type +def flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: + outputs: Sequence[T] + if isinstance(output_spec, Sequence): + outputs = output_spec else: - output_types = (output_type,) + outputs = (output_spec,) - output_types_flat: list[T] = [] - for output_type in output_types: - if union_types := get_union_args(output_type): - output_types_flat.extend(union_types) + outputs_flat: list[T] = [] + for output in outputs: + if union_types := get_union_args(output): + outputs_flat.extend(union_types) else: - output_types_flat.append(output_type) - return output_types_flat + outputs_flat.append(output) + return outputs_flat diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 77c34fbac..a280115f8 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -9,7 +9,7 @@ from datetime import datetime, timezone from functools import partial from types import GenericAlias -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast from anyio.to_thread import run_sync from pydantic import BaseModel, TypeAdapter @@ -302,3 +302,82 @@ def dataclasses_no_defaults_repr(self: Any) -> str: def number_to_datetime(x: int | float) -> datetime: return TypeAdapter(datetime).validate_python(x) + + +def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None: + """Update $refs in a schema to use the new names from name_mapping.""" + if '$ref' in s: + ref = s['$ref'] + if ref.startswith('#/$defs/'): + original_name = ref[8:] # Remove '#/$defs/' + new_name = name_mapping.get(original_name, original_name) + s['$ref'] = f'#/$defs/{new_name}' + + # Recursively update refs in properties + if 'properties' in s: + props: dict[str, Any] = s['properties'] + for prop in props.values(): + if isinstance(prop, dict): + prop = cast(dict[str, Any], prop) + _update_mapped_json_schema_refs(prop, name_mapping) + + # Handle arrays + if 'items' in s and isinstance(s['items'], dict): + items: dict[str, Any] = s['items'] + _update_mapped_json_schema_refs(items, name_mapping) + if 'prefixItems' in s: + prefix_items: list[dict[str, Any]] = s['prefixItems'] + for item in prefix_items: + if isinstance(item, dict): + _update_mapped_json_schema_refs(item, name_mapping) + + # Handle unions + for union_type in ['anyOf', 'oneOf']: + if union_type in s: + union_items: list[dict[str, Any]] = s[union_type] + for item in union_items: + if isinstance(item, dict): + _update_mapped_json_schema_refs(item, name_mapping) + + +def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]: + """Merges the `$defs` from different JSON schemas into a single deduplicated `$defs`, handling name collisions of `$defs` that are not the same, and rewrites `$ref`s to point to the new `$defs`. + + Returns a tuple of the rewritten schemas and a dictionary of the new `$defs`. + """ + all_defs: dict[str, dict[str, Any]] = {} + rewritten_schemas: list[dict[str, Any]] = [] + + for schema in schemas: + if '$defs' not in schema: + rewritten_schemas.append(schema) + continue + + schema = schema.copy() + defs = schema.pop('$defs', None) + schema_name_mapping: dict[str, str] = {} + + # Process definitions and build mapping + for name, def_schema in defs.items(): + if name not in all_defs: + all_defs[name] = def_schema + schema_name_mapping[name] = name + elif def_schema != all_defs[name]: + new_name = name + if title := schema.get('title'): + new_name = f'{title}_{name}' + + i = 1 + original_new_name = new_name + new_name = f'{new_name}_{i}' + while new_name in all_defs: + i += 1 + new_name = f'{original_new_name}_{i}' + + all_defs[new_name] = def_schema + schema_name_mapping[name] = new_name + + _update_mapped_json_schema_refs(schema, schema_name_mapping) + rewritten_schemas.append(schema) + + return rewritten_schemas, all_defs diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index ca18baa46..a4ea05329 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -128,7 +128,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ - output_type: _output.OutputType[OutputDataT] + output_type: _output.OutputSpec[OutputDataT] """ The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. """ @@ -163,7 +163,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - output_type: _output.OutputType[OutputDataT] = str, + output_type: _output.OutputSpec[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] @@ -377,7 +377,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT], + output_type: _output.OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -407,7 +407,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -495,7 +495,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: _output.OutputType[RunOutputDataT], + output_type: _output.OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -527,7 +527,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -781,7 +781,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -811,7 +811,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -894,7 +894,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent], *, - output_type: _output.OutputType[RunOutputDataT], + output_type: _output.OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -925,7 +925,7 @@ async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -1639,7 +1639,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( - self, output_type: _output.OutputType[RunOutputDataT] | None, model_profile: ModelProfile + self, output_type: _output.OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile ) -> _output.OutputSchema[RunOutputDataT]: if output_type is not None: if self._output_validators: diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 068c8c116..c5d769e79 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1374,7 +1374,7 @@ class CountryLanguage(BaseModel): instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a783c36d2..c8879e33f 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1832,7 +1832,7 @@ class CountryLanguage(BaseModel): instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, diff --git a/tests/models/test_google.py b/tests/models/test_google.py index a2429e5b4..daa480f4e 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -1038,7 +1038,7 @@ class CountryLanguage(BaseModel): instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 169e4b980..e82322687 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -2251,7 +2251,7 @@ async def get_user_country() -> str: instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, @@ -2289,7 +2289,7 @@ async def get_user_country() -> str: instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index e01f9bcb9..524537729 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -912,7 +912,7 @@ async def get_user_country() -> str: instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, @@ -943,7 +943,7 @@ async def get_user_country() -> str: instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} Don't include any text or Markdown fencing before or after.\ """, diff --git a/tests/test_agent.py b/tests/test_agent.py index dcad1e9b9..751b5036e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -13,7 +13,7 @@ from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import JsonSchemaOutput, OutputType, PromptedJsonOutput, TextOutput, ToolOutput +from pydantic_ai._output import JsonSchemaOutput, OutputSpec, PromptedJsonOutput, TextOutput, ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -931,7 +931,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'output_type', [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], ) -def test_output_type_multiple_text_output(output_type: OutputType[str]): +def test_output_type_multiple_text_output(output_type: OutputSpec[str]): with pytest.raises(UserError, match='Only one text output is allowed.'): Agent('test', output_type=output_type) @@ -1304,6 +1304,80 @@ class CityLocation(BaseModel): ) +def test_output_type_prompted_json_with_defs(): + class Foo(BaseModel): + """Foo description""" + + foo: str + + class Bar(BaseModel): + """Bar description""" + + bar: str + + class Baz(BaseModel): + """Baz description""" + + baz: str + + class FooBar(BaseModel): + """FooBar description""" + + foo: Foo + bar: Bar + + class FooBaz(BaseModel): + """FooBaz description""" + + foo: Foo + baz: Baz + + def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + text = '{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' + return ModelResponse(parts=[TextPart(content=text)]) + + m = FunctionModel(return_foo_bar) + + agent = Agent( + m, + output_type=PromptedJsonOutput( + [FooBar, FooBaz], name='FooBar or FooBaz', description='FooBar or FooBaz description' + ), + ) + + result = agent.run_sync('What is foo?') + assert result.output == snapshot(FooBar(foo=Foo(foo='foo'), bar=Bar(bar='bar'))) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is foo?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "FooBar"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "bar": {"$ref": "#/$defs/Bar"}}, "required": ["foo", "bar"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBar", "description": "FooBar description"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "FooBaz"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "baz": {"$ref": "#/$defs/Baz"}}, "required": ["foo", "baz"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBaz", "description": "FooBaz description"}]}}, "required": ["result"], "additionalProperties": false, "$defs": {"Bar": {"description": "Bar description", "properties": {"bar": {"type": "string"}}, "required": ["bar"], "title": "Bar", "type": "object"}, "Foo": {"description": "Foo description", "properties": {"foo": {"type": "string"}}, "required": ["foo"], "title": "Foo", "type": "object"}, "Baz": {"description": "Baz description", "properties": {"baz": {"type": "string"}}, "required": ["baz"], "title": "Baz", "type": "object"}}, "title": "FooBar or FooBaz", "description": "FooBaz description"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' + ) + ], + usage=Usage(requests=1, request_tokens=53, response_tokens=17, total_tokens=70), + model_name='function:return_foo_bar:', + timestamp=IsDatetime(), + ), + ] + ) + + def test_output_type_json_schema(): def return_city_location(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: if len(messages) == 1: diff --git a/tests/test_utils.py b/tests/test_utils.py index e7d3ddcf3..fdd042b46 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,14 @@ from inline_snapshot import snapshot from pydantic_ai import UserError -from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal, run_in_executor +from pydantic_ai._utils import ( + UNSET, + PeekableAsyncStream, + check_object_json_schema, + group_by_temporal, + merge_json_schema_defs, + run_in_executor, +) from .models.mock_async_stream import MockAsyncStream @@ -153,3 +160,197 @@ async def test_run_in_executor_with_contextvars() -> None: # show that the old version did not work old_result = asyncio.get_running_loop().run_in_executor(None, ctx_var.get) assert old_result != ctx_var.get() + + +def test_merge_json_schema_defs(): + foo_bar_schema = { + '$defs': { + 'Bar': { + 'description': 'Bar description', + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo': { + 'description': 'Foo description', + 'properties': {'foo': {'type': 'string'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + 'title': 'FooBar', + } + + foo_bar_baz_schema = { + '$defs': { + 'Baz': { + 'description': 'Baz description', + 'properties': {'baz': {'type': 'string'}}, + 'required': ['baz'], + 'title': 'Baz', + 'type': 'object', + }, + 'Foo': { + 'description': 'Foo description. Note that this is different from the Foo in foo_bar_schema!', + 'properties': {'foo': {'type': 'int'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar': { + 'description': 'Bar description', + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'baz': {'$ref': '#/$defs/Baz'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'baz', 'bar'], + 'type': 'object', + 'title': 'FooBarBaz', + } + + # A schema with no title that will cause numeric suffixes + no_title_schema = { + '$defs': { + 'Foo': { + 'description': 'Another different Foo', + 'properties': {'foo': {'type': 'boolean'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar': { + 'description': 'Another different Bar', + 'properties': {'bar': {'type': 'number'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + } + + # Another schema with no title that will cause more numeric suffixes + another_no_title_schema = { + '$defs': { + 'Foo': { + 'description': 'Yet another different Foo', + 'properties': {'foo': {'type': 'array'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar': { + 'description': 'Yet another different Bar', + 'properties': {'bar': {'type': 'object'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + } + + schemas = [foo_bar_schema, foo_bar_baz_schema, no_title_schema, another_no_title_schema] + rewritten_schemas, all_defs = merge_json_schema_defs(schemas) + assert all_defs == snapshot( + { + 'Bar': { + 'description': 'Bar description', + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo': { + 'description': 'Foo description', + 'properties': {'foo': {'type': 'string'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Baz': { + 'description': 'Baz description', + 'properties': {'baz': {'type': 'string'}}, + 'required': ['baz'], + 'title': 'Baz', + 'type': 'object', + }, + 'FooBarBaz_Foo_1': { + 'description': 'Foo description. Note that this is different from the Foo in foo_bar_schema!', + 'properties': {'foo': {'type': 'int'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Foo_1': { + 'description': 'Another different Foo', + 'properties': {'foo': {'type': 'boolean'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar_1': { + 'description': 'Another different Bar', + 'properties': {'bar': {'type': 'number'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo_2': { + 'description': 'Yet another different Foo', + 'properties': {'foo': {'type': 'array'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar_2': { + 'description': 'Yet another different Bar', + 'properties': {'bar': {'type': 'object'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + } + ) + assert rewritten_schemas == snapshot( + [ + { + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + 'title': 'FooBar', + }, + { + 'properties': { + 'foo': {'$ref': '#/$defs/FooBarBaz_Foo_1'}, + 'baz': {'$ref': '#/$defs/Baz'}, + 'bar': {'$ref': '#/$defs/Bar'}, + }, + 'required': ['foo', 'baz', 'bar'], + 'type': 'object', + 'title': 'FooBarBaz', + }, + { + 'properties': {'foo': {'$ref': '#/$defs/Foo_1'}, 'bar': {'$ref': '#/$defs/Bar_1'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + }, + { + 'properties': {'foo': {'$ref': '#/$defs/Foo_2'}, 'bar': {'$ref': '#/$defs/Bar_2'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + }, + ] + ) From f495d4693c40a8bb84b49a651ef3473594423917 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 12 Jun 2025 00:06:29 +0000 Subject: [PATCH 22/90] Refactor output schemas and modes: more 'isinstance(output_schema, ...)', less 'output_schema.mode == ...' --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 45 +- pydantic_ai_slim/pydantic_ai/_output.py | 644 ++++++++++++------ pydantic_ai_slim/pydantic_ai/agent.py | 36 +- .../pydantic_ai/models/anthropic.py | 2 +- .../pydantic_ai/models/bedrock.py | 2 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 2 +- .../pydantic_ai/models/mistral.py | 2 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 4 +- .../pydantic_ai/profiles/__init__.py | 20 +- pydantic_ai_slim/pydantic_ai/result.py | 30 +- tests/test_agent.py | 14 +- 13 files changed, 520 insertions(+), 285 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 657de060b..44602715e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -260,14 +260,12 @@ async def add_mcp_server_tools(server: MCPServer) -> None: function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] output_schema = ctx.deps.output_schema - assert output_schema.mode is not None # Should have been set in agent._prepare_output_schema - return models.ModelRequestParameters( function_tools=function_tool_defs, output_mode=output_schema.mode, - output_object=output_schema.text_output_schema.object_def if output_schema.text_output_schema else None, - output_tools=output_schema.tool_defs(), - allow_text_output=output_schema.allow_text_output == 'plain', + output_object=output_schema.object_def if isinstance(output_schema, _output.JsonTextOutputSchema) else None, + output_tools=output_schema.tool_defs() if isinstance(output_schema, _output.ToolOutputSchema) else [], + allow_text_output=isinstance(output_schema, _output.TextOutputSchema), ) @@ -452,7 +450,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # when the model has already returned text along side tool calls # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any - if ctx.deps.output_schema.allow_text_output: + if isinstance(ctx.deps.output_schema, _output.TextOutputSchema): for message in reversed(ctx.state.message_history): if isinstance(message, _messages.ModelResponse): last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)] @@ -475,21 +473,23 @@ async def _handle_tool_calls( output_schema = ctx.deps.output_schema run_context = build_run_context(ctx) - # first, look for the output tool call final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] - for call, output_tool in output_schema.find_tool(tool_calls): - try: - result_data = await output_tool.process(call, run_context) - result_data = await _validate_output(result_data, ctx, call) - except _output.ToolRetryError as e: - # TODO: Should only increment retry stuff once per node execution, not for each tool call - # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries, e) - parts.append(e.tool_retry) - else: - final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - break + + # first, look for the output tool call + if isinstance(output_schema, _output.ToolOutputSchema): + for call, output_tool in output_schema.find_tool(tool_calls): + try: + result_data = await output_tool.process(call, run_context) + result_data = await _validate_output(result_data, ctx, call) + except _output.ToolRetryError as e: + # TODO: Should only increment retry stuff once per node execution, not for each tool call + # Also, should increment the tool-specific retry count rather than the run retry count + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + parts.append(e.tool_retry) + else: + final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) + break # Then build the other request parts based on end strategy tool_responses: list[_messages.ModelRequestPart] = self._tool_responses @@ -535,7 +535,7 @@ async def _handle_text_response( text = '\n\n'.join(texts) try: - if output_schema.allow_text_output: + if isinstance(output_schema, _output.TextOutputSchema): run_context = build_run_context(ctx) result_data = await output_schema.process(text, run_context) else: @@ -765,7 +765,10 @@ def _unknown_tool( ) -> _messages.RetryPromptPart: ctx.state.increment_retries(ctx.deps.max_result_retries) tool_names = list(ctx.deps.function_tools.keys()) - tool_names.extend(ctx.deps.output_schema.tool_names()) + + output_schema = ctx.deps.output_schema + if isinstance(output_schema, _output.ToolOutputSchema): + tool_names.extend(output_schema.tool_names()) if tool_names: msg = f'Available tools: {", ".join(tool_names)}' diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 33fcca216..471bfa7d4 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -3,19 +3,17 @@ import inspect import json import re +from abc import ABC, abstractmethod from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field -from textwrap import dedent -from typing import Any, Callable, Generic, Literal, Union, cast +from typing import Any, Callable, Generic, Literal, Union, cast, overload from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator -from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin +from typing_extensions import TypeAliasType, TypedDict, TypeVar, assert_never, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin -from pydantic_ai.profiles import ModelProfile - from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry, UserError from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition @@ -55,15 +53,6 @@ DEFAULT_OUTPUT_TOOL_NAME = 'final_result' DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' -DEFAULT_PROMPTED_JSON_PROMPT = dedent( - """ - Always respond with a JSON object that's compatible with this schema: - - {schema} - - Don't include any text or Markdown fencing before or after. - """ -) @dataclass @@ -225,52 +214,94 @@ def __init__( type_params=(T_co,), ) -OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'prompted_json'] + +OutputMode = Literal['text', 'tool', 'json_schema', 'prompted_json', 'tool_or_text'] +"""All output modes.""" +SupportableOutputMode = Literal['tool', 'json_schema'] +"""Output modes that require specific support by a model (class). Used by ModelProfile.output_modes""" +StructuredOutputMode = Literal['tool', 'json_schema', 'prompted_json'] +"""Output modes that can be used for any structured output. Used by ModelProfile.default_output_mode""" + + +class BaseOutputSchema(ABC, Generic[OutputDataT]): + @property + @abstractmethod + def mode(self) -> OutputMode | None: + raise NotImplementedError() + + @abstractmethod + def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: + raise NotImplementedError() + + @abstractmethod + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + raise NotImplementedError() + + @property + def tools(self) -> dict[str, OutputTool[OutputDataT]]: + """Get the tools for this output schema.""" + return {} @dataclass(init=False) -class OutputSchema(Generic[OutputDataT]): +class OutputSchema(BaseOutputSchema[OutputDataT], ABC): """Model the final output from an agent run.""" - mode: OutputMode | None = None - text_output_schema: ( - OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | OutputFunctionSchema[OutputDataT] | None - ) = None - tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + @classmethod + @overload + def build( + cls, + output_spec: OutputSpec[OutputDataT], + *, + default_mode: StructuredOutputMode, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> OutputSchema[OutputDataT]: ... - def __init__( - self, + @classmethod + @overload + def build( + cls, output_spec: OutputSpec[OutputDataT], *, + default_mode: None = None, name: str | None = None, description: str | None = None, strict: bool | None = None, - ): - """Build an OutputSchema dataclass from an output type.""" - self.mode = None - self.text_output_schema = None - self.tools = {} + ) -> OutputSchemaWithoutMode[OutputDataT]: ... + @classmethod + def build( + cls, + output_spec: OutputSpec[OutputDataT], + *, + default_mode: StructuredOutputMode | None = None, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> BaseOutputSchema[OutputDataT]: + """Build an OutputSchema dataclass from an output type.""" if output_spec is str: - self.mode = 'text' - return + return PlainTextOutputSchema() if isinstance(output_spec, JsonSchemaOutput): - self.mode = 'json_schema' - self.text_output_schema = self._build_text_output_schema( - output_spec.outputs, - name=output_spec.name, - description=output_spec.description, - strict=output_spec.strict, + return JsonSchemaOutputSchema( + text_processor=cls._build_text_processor( + output_spec.outputs, + name=output_spec.name, + description=output_spec.description, + strict=output_spec.strict, + ), ) - return if isinstance(output_spec, PromptedJsonOutput): - self.mode = 'prompted_json' - self.text_output_schema = self._build_text_output_schema( - output_spec.outputs, name=output_spec.name, description=output_spec.description + return PromptedJsonOutputSchema( + text_processor=cls._build_text_processor( + output_spec.outputs, name=output_spec.name, description=output_spec.description + ), ) - return text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] @@ -285,27 +316,37 @@ def __init__( else: other_outputs.append(output) - self.tools = self._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) + tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) if len(text_outputs) > 0: if len(text_outputs) > 1: raise UserError('Only one text output is allowed.') text_output = text_outputs[0] - self.mode = 'text' - if len(self.tools) > 0: - self.mode = 'tool_or_text' - + text_output_schema = None if isinstance(text_output, TextOutput): - self.text_output_schema = OutputFunctionSchema(text_output.output_function) - elif len(tool_outputs) > 0: - self.mode = 'tool' - elif len(other_outputs) > 0: - self.text_output_schema = self._build_text_output_schema( - other_outputs, name=name, description=description, strict=strict + text_output_schema = PlainTextOutputProcessor(text_output.output_function) + + if len(tools) == 0: + return PlainTextOutputSchema(text_processor=text_output_schema) + else: + return ToolOrTextOutputSchema(text_processor=text_output_schema, tools=tools) + + if len(tool_outputs) > 0: + return ToolOutputSchema(tools=tools) + + if len(other_outputs) > 0: + schema = OutputSchemaWithoutMode( + text_processor=cls._build_text_processor( + other_outputs, name=name, description=description, strict=strict + ), + tools=tools, ) - else: - raise UserError('No output type provided.') # pragma: no cover + if default_mode: + schema = schema.with_default_mode(default_mode) + return schema + + raise UserError('No output type provided.') # pragma: no cover @staticmethod def _build_tools( @@ -349,73 +390,99 @@ def _build_tools( if strict is None: strict = default_strict - parameters_schema = OutputObjectSchema(output=output_type, description=description, strict=strict) + parameters_schema = ObjectOutputProcessor(output=output_type, description=description, strict=strict) tools[name] = OutputTool(name=name, parameters_schema=parameters_schema, multiple=multiple) return tools @staticmethod - def _build_text_output_schema( + def _build_text_processor( outputs: Sequence[OutputTypeOrFunction[OutputDataT]], name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | None: - if len(outputs) == 0: - return None # pragma: no cover - + ) -> ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]: outputs = flatten_output_spec(outputs) if len(outputs) == 1: - return OutputObjectSchema(output=outputs[0], name=name, description=description, strict=strict) + return ObjectOutputProcessor(output=outputs[0], name=name, description=description, strict=strict) - return OutputUnionSchema(outputs=outputs, strict=strict, name=name, description=description) + return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description) @property - def allow_text_output(self) -> Literal['plain', 'json', False]: - """Whether the model allows text output.""" - if self.mode == 'tool': - return False - if self.mode in ('text', 'tool_or_text'): - return 'plain' - return 'json' - - def is_mode_supported(self, profile: ModelProfile) -> bool: - """Whether the model supports the output mode.""" - mode = self.mode - if mode in ('text', 'prompted_json'): - return True - if self.mode == 'tool_or_text': - mode = 'tool' - return mode in profile.output_modes + @abstractmethod + def mode(self) -> OutputMode: + raise NotImplementedError() - def find_named_tool( - self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: - """Find a tool that matches one of the calls, with a specific name.""" - for part in parts: # pragma: no branch - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if part.tool_name == tool_name: - return part, self.tools[tool_name] + def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: + return self - def find_tool( + +@dataclass(init=False) +class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): + text_processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] + _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + + def __init__( self, - parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: - """Find a tool that matches one of the calls.""" - for part in parts: - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if result := self.tools.get(part.tool_name): - yield part, result + text_processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], + tools: dict[str, OutputTool[OutputDataT]], + ): + self.text_processor = text_processor + self._tools = tools - def tool_names(self) -> list[str]: - """Return the names of the tools.""" - return list(self.tools.keys()) + @property + def mode(self) -> None: + return None - def tool_defs(self) -> list[ToolDefinition]: - """Get tool definitions to register with the model.""" - if self.mode not in ('tool', 'tool_or_text'): - return [] - return [t.tool_def for t in self.tools.values()] + def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: + if mode == 'json_schema': + return JsonSchemaOutputSchema( + text_processor=self.text_processor, + ) + elif mode == 'prompted_json': + return PromptedJsonOutputSchema( + text_processor=self.text_processor, + ) + elif mode == 'tool': + return ToolOutputSchema(tools=self.tools) + else: + assert_never(mode) + + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return False + + @property + def tools(self) -> dict[str, OutputTool[OutputDataT]]: + """Get the tools for this output schema.""" + # We return tools here as they're checked in Agent._register_tool. + # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time. + return self._tools + + +class TextOutputSchema(OutputSchema[OutputDataT], ABC): + @abstractmethod + async def process( + self, + text: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + raise NotImplementedError() + + +@dataclass +class PlainTextOutputSchema(TextOutputSchema[OutputDataT]): + text_processor: PlainTextOutputProcessor[OutputDataT] | None = None + + @property + def mode(self) -> OutputMode: + return 'text' + + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return True async def process( self, @@ -435,11 +502,41 @@ async def process( Returns: Either the validated output data (left) or a retry message (right). """ - assert self.allow_text_output is not False - - if self.text_output_schema is None: + if self.text_processor is None: return cast(OutputDataT, text) + return await self.text_processor.process( + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + +@dataclass +class JsonTextOutputSchema(TextOutputSchema[OutputDataT], ABC): + text_processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] + + @property + def object_def(self) -> OutputObjectDefinition: + return self.text_processor.object_def + + async def process( + self, + text: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + """Validate an output message. + + Args: + text: The output text to validate. + run_context: The current run context. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + def strip_markdown_fences(text: str) -> str: if text.startswith('{'): return text @@ -453,158 +550,133 @@ def strip_markdown_fences(text: str) -> str: text = strip_markdown_fences(text) - return await self.text_output_schema.process( + return await self.text_processor.process( text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) -@dataclass -class OutputObjectDefinition: - json_schema: ObjectJsonSchema - name: str | None = None - description: str | None = None - strict: bool | None = None +class JsonSchemaOutputSchema(JsonTextOutputSchema[OutputDataT]): + @property + def mode(self) -> OutputMode: + return 'json_schema' + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return 'json_schema' in supported_modes + + +class PromptedJsonOutputSchema(JsonTextOutputSchema[OutputDataT]): @property - def instructions(self) -> str: - """Get instructions for model to output manual JSON matching the schema.""" - schema = self.json_schema.copy() - if self.name: - schema['title'] = self.name - if self.description: - schema['description'] = self.description + def mode(self) -> OutputMode: + return 'prompted_json' - # Eventually move DEFAULT_PROMPTED_JSON_PROMPT to ModelProfile so it can be tweaked on a per model basis - return DEFAULT_PROMPTED_JSON_PROMPT.format(schema=json.dumps(schema)) + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return True + def instructions(self, template: str) -> str: + """Get instructions for model to output manual JSON matching the schema.""" + object_def = self.object_def + schema = object_def.json_schema.copy() + if object_def.name: + schema['title'] = object_def.name + if object_def.description: + schema['description'] = object_def.description -@dataclass(init=False) -class OutputUnionDataEntry: - kind: str - data: dict[str, Any] + return template.format(schema=json.dumps(schema)) @dataclass(init=False) -class OutputUnionData: - result: OutputUnionDataEntry +class ToolOutputSchema(OutputSchema[OutputDataT]): + _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + def __init__(self, tools: dict[str, OutputTool[OutputDataT]]): + self._tools = tools -# TODO: Better class naming -@dataclass(init=False) -class OutputUnionSchema(Generic[OutputDataT]): - object_def: OutputObjectDefinition - _root_object_schema: OutputObjectSchema[OutputUnionData] - _object_schemas: dict[str, OutputObjectSchema[OutputDataT]] + @property + def mode(self) -> OutputMode: + return 'tool' - def __init__( - self, - outputs: Sequence[OutputTypeOrFunction[OutputDataT]], - *, - name: str | None = None, - description: str | None = None, - strict: bool | None = None, - ): - json_schemas: list[ObjectJsonSchema] = [] - self._object_schemas = {} - for output in outputs: - object_schema = OutputObjectSchema(output=output, strict=strict) - object_def = object_schema.object_def + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return 'tool' in supported_modes - object_key = object_def.name or output.__name__ - i = 1 - original_key = object_key - while object_key in self._object_schemas: - i += 1 - object_key = f'{original_key}_{i}' + @property + def tools(self) -> dict[str, OutputTool[OutputDataT]]: + """Get the tools for this output schema.""" + return self._tools - self._object_schemas[object_key] = object_schema + def tool_names(self) -> list[str]: + """Return the names of the tools.""" + return list(self.tools.keys()) - json_schema = object_def.json_schema - if object_name := object_def.name: - json_schema['title'] = object_name - if object_description := object_def.description: - json_schema['description'] = object_description + def tool_defs(self) -> list[ToolDefinition]: + """Get tool definitions to register with the model.""" + return [t.tool_def for t in self.tools.values()] - json_schemas.append(json_schema) + def find_named_tool( + self, parts: Iterable[_messages.ModelResponsePart], tool_name: str + ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: + """Find a tool that matches one of the calls, with a specific name.""" + for part in parts: # pragma: no branch + if isinstance(part, _messages.ToolCallPart): # pragma: no branch + if part.tool_name == tool_name: + return part, self.tools[tool_name] - json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas) + def find_tool( + self, + parts: Iterable[_messages.ModelResponsePart], + ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: + """Find a tool that matches one of the calls.""" + for part in parts: + if isinstance(part, _messages.ToolCallPart): # pragma: no branch + if result := self.tools.get(part.tool_name): + yield part, result - discriminated_json_schemas: list[ObjectJsonSchema] = [] - for object_key, json_schema in zip(self._object_schemas.keys(), json_schemas): - title = json_schema.pop('title', None) - description = json_schema.pop('description', None) - discriminated_json_schema = { - 'type': 'object', - 'properties': { - 'kind': { - 'type': 'string', - 'const': object_key, - }, - 'data': json_schema, - }, - 'required': ['kind', 'data'], - 'additionalProperties': False, - } - if title: - discriminated_json_schema['title'] = title - if description: - discriminated_json_schema['description'] = description +@dataclass(init=False) +class ToolOrTextOutputSchema(PlainTextOutputSchema[OutputDataT], ToolOutputSchema[OutputDataT]): + def __init__( + self, + text_processor: PlainTextOutputProcessor[OutputDataT] | None, + tools: dict[str, OutputTool[OutputDataT]], + ): + self.text_processor = text_processor + self._tools = tools - discriminated_json_schemas.append(discriminated_json_schema) + @property + def mode(self) -> OutputMode: + return 'tool_or_text' - self._root_object_schema = OutputObjectSchema(output=OutputUnionData) + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return 'tool' in supported_modes - json_schema = { - 'type': 'object', - 'properties': { - 'result': { - 'anyOf': discriminated_json_schemas, - } - }, - 'required': ['result'], - 'additionalProperties': False, - } - if all_defs: - json_schema['$defs'] = all_defs - self.object_def = OutputObjectDefinition( - json_schema=json_schema, - strict=strict, - name=name, - description=description, - ) +@dataclass +class OutputObjectDefinition: + json_schema: ObjectJsonSchema + name: str | None = None + description: str | None = None + strict: bool | None = None + +@dataclass(init=False) +class BaseOutputProcessor(ABC, Generic[OutputDataT]): + @abstractmethod async def process( self, - data: str | dict[str, Any] | None, + data: str, run_context: RunContext[AgentDepsT], allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: - result = await self._root_object_schema.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors - ) - - result = result.result - kind = result.kind - data = result.data - try: - object_schema = self._object_schemas[kind] - except KeyError as e: - if wrap_validation_errors: - m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}') - raise ToolRetryError(m) from e - else: - raise # pragma: lax no cover - - return await object_schema.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors - ) + """Process an output message, performing validation and (if necessary) calling the output function.""" + raise NotImplementedError() @dataclass(init=False) -class OutputObjectSchema(Generic[OutputDataT]): +class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]): object_def: OutputObjectDefinition outer_typed_dict_key: str | None = None _validator: SchemaValidator @@ -709,8 +781,132 @@ async def process( return output +@dataclass +class UnionOutputResult: + kind: str + data: ObjectJsonSchema + + +@dataclass +class UnionOutputModel: + result: UnionOutputResult + + +@dataclass(init=False) +class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]): + object_def: OutputObjectDefinition + _union_schema: ObjectOutputProcessor[UnionOutputModel] + _object_schemas: dict[str, ObjectOutputProcessor[OutputDataT]] + + def __init__( + self, + outputs: Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ): + self._union_schema = ObjectOutputProcessor(output=UnionOutputModel) + + json_schemas: list[ObjectJsonSchema] = [] + self._object_schemas = {} + for output in outputs: + object_schema = ObjectOutputProcessor(output=output, strict=strict) + object_def = object_schema.object_def + + object_key = object_def.name or output.__name__ + i = 1 + original_key = object_key + while object_key in self._object_schemas: + i += 1 + object_key = f'{original_key}_{i}' + + self._object_schemas[object_key] = object_schema + + json_schema = object_def.json_schema + if object_name := object_def.name: + json_schema['title'] = object_name + if object_description := object_def.description: + json_schema['description'] = object_description + + json_schemas.append(json_schema) + + json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas) + + discriminated_json_schemas: list[ObjectJsonSchema] = [] + for object_key, json_schema in zip(self._object_schemas.keys(), json_schemas): + title = json_schema.pop('title', None) + description = json_schema.pop('description', None) + + discriminated_json_schema = { + 'type': 'object', + 'properties': { + 'kind': { + 'type': 'string', + 'const': object_key, + }, + 'data': json_schema, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + } + if title: + discriminated_json_schema['title'] = title + if description: + discriminated_json_schema['description'] = description + + discriminated_json_schemas.append(discriminated_json_schema) + + json_schema = { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': discriminated_json_schemas, + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + if all_defs: + json_schema['$defs'] = all_defs + + self.object_def = OutputObjectDefinition( + json_schema=json_schema, + strict=strict, + name=name, + description=description, + ) + + async def process( + self, + data: str | dict[str, Any] | None, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + union_object = await self._union_schema.process( + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + result = union_object.result + kind = result.kind + data = result.data + try: + object_schema = self._object_schemas[kind] + except KeyError as e: + if wrap_validation_errors: + m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}') + raise ToolRetryError(m) from e + else: + raise # pragma: lax no cover + + return await object_schema.process( + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + @dataclass(init=False) -class OutputFunctionSchema(Generic[OutputDataT]): +class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]): _function_schema: _function_schema.FunctionSchema _str_argument_name: str @@ -758,10 +954,10 @@ async def process( @dataclass(init=False) class OutputTool(Generic[OutputDataT]): - parameters_schema: OutputObjectSchema[OutputDataT] + parameters_schema: ObjectOutputProcessor[OutputDataT] tool_def: ToolDefinition - def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool): + def __init__(self, *, name: str, parameters_schema: ObjectOutputProcessor[OutputDataT], multiple: bool): self.parameters_schema = parameters_schema definition = parameters_schema.object_def diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index a4ea05329..b48650adf 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -141,7 +141,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) - _output_schema: _output.OutputSchema[OutputDataT] = dataclasses.field(repr=False) + _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) _instructions: str | None = dataclasses.field(repr=False) _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) @@ -318,8 +318,10 @@ def __init__( warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning) output_retries = result_retries - self._output_schema = _output.OutputSchema[OutputDataT]( + default_output_mode = self.model.profile.default_output_mode if isinstance(self.model, models.Model) else None + self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, + default_mode=default_output_mode, name=self._deprecated_result_tool_name, description=self._deprecated_result_tool_description, ) @@ -674,12 +676,10 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: *[await func.run(run_context) for func in self._instructions_functions], ] - if ( - output_schema.mode == 'prompted_json' - and (output_object_schema := output_schema.text_output_schema) - and (object_def := output_object_schema.object_def) - ): - parts.append(object_def.instructions) + if isinstance(output_schema, _output.PromptedJsonOutputSchema): + template = model_used.profile.prompted_json_output_instructions + instructions = output_schema.instructions(template) + parts.append(instructions) parts = [p for p in parts if p] if not parts: @@ -1004,10 +1004,13 @@ async def stream_to_final( async for maybe_part_event in streamed_response: if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part - if isinstance(new_part, _messages.TextPart): - if output_schema.allow_text_output: - return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart): # pragma: no branch + if isinstance(new_part, _messages.TextPart) and isinstance( + output_schema, _output.TextOutputSchema + ): + return FinalResult(s, None, None) + elif isinstance(new_part, _messages.ToolCallPart) and isinstance( + output_schema, _output.ToolOutputSchema + ): # pragma: no branch for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) return None @@ -1644,17 +1647,16 @@ def _prepare_output_schema( if output_type is not None: if self._output_validators: raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') - schema = _output.OutputSchema[RunOutputDataT]( + schema = _output.OutputSchema[RunOutputDataT].build( output_type, name=self._deprecated_result_tool_name, description=self._deprecated_result_tool_description, + default_mode=model_profile.default_output_mode, ) else: - schema = self._output_schema + schema = self._output_schema.with_default_mode(model_profile.default_output_mode) - if schema.mode is None: - schema.mode = model_profile.default_output_mode - if not schema.is_mode_supported(model_profile): + if not schema.is_supported(model_profile.output_modes): modes = ', '.join(f"'{m}'" for m in model_profile.output_modes) raise exceptions.UserError(f"Output mode '{schema.mode}' is not among supported modes: {modes}") diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 35b8e8a52..69808aa5f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -213,7 +213,7 @@ async def _messages_create( if not tools: tool_choice = None else: - if model_request_parameters.output_mode == 'tool': + if not model_request_parameters.allow_text_output: tool_choice = {'type': 'any'} else: tool_choice = {'type': 'auto'} diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 3ed16e726..fb8753b43 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -378,7 +378,7 @@ def _map_tool_config(self, model_request_parameters: ModelRequestParameters) -> return None tool_choice: ToolChoiceTypeDef - if model_request_parameters.output_mode == 'tool': + if not model_request_parameters.allow_text_output: tool_choice = {'any': {}} else: tool_choice = {'auto': {}} diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 14c996acb..96c026c32 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -194,7 +194,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: - if model_request_parameters.output_mode == 'tool' and tools: + if not model_request_parameters.allow_text_output and tools: return _tool_config([t['name'] for t in tools['function_declarations']]) else: return None diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index d041d13b4..f5ef30d2c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -215,7 +215,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: - if model_request_parameters.output_mode == 'tool' and tools: + if not model_request_parameters.allow_text_output and tools: names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index 20de2f63e..1b23d2f5f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -209,7 +209,7 @@ async def _completions_create( # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif model_request_parameters.output_mode == 'tool': + elif not model_request_parameters.allow_text_output: tool_choice = 'required' else: tool_choice = 'auto' diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index a5f096817..8c558bca7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -287,7 +287,7 @@ def _get_tool_choice(self, model_request_parameters: ModelRequestParameters) -> """ if not model_request_parameters.function_tools and not model_request_parameters.output_tools: return None - elif model_request_parameters.output_mode == 'tool': + elif not model_request_parameters.allow_text_output: return 'required' else: return 'auto' diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index f46172c92..6856dc032 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -266,7 +266,7 @@ async def _completions_create( tools = self._get_tools(model_request_parameters) if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif model_request_parameters.output_mode == 'tool': + elif not model_request_parameters.allow_text_output: tool_choice = 'required' else: tool_choice = 'auto' @@ -674,7 +674,7 @@ async def _responses_create( if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None - elif model_request_parameters.output_mode == 'tool': + elif not model_request_parameters.allow_text_output: tool_choice = 'required' else: tool_choice = 'auto' diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index aad62f2a9..5f7119bff 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -1,10 +1,13 @@ from __future__ import annotations as _annotations from dataclasses import dataclass, field, fields, replace -from typing import Callable, Literal, Union +from textwrap import dedent +from typing import Callable, Union from typing_extensions import Self +from pydantic_ai._output import StructuredOutputMode, SupportableOutputMode + from ._json_schema import JsonSchemaTransformer @@ -14,11 +17,22 @@ class ModelProfile: json_schema_transformer: type[JsonSchemaTransformer] | None = None """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" - output_modes: set[Literal['tool', 'json_schema']] = field(default_factory=lambda: {'tool'}) + output_modes: set[SupportableOutputMode] = field(default_factory=lambda: {'tool'}) """The output modes supported by the model. Essentially all models support `tool` mode, but some also support `json_schema` mode, which needs to be specifically implemented on the model class.""" - default_output_mode: Literal['tool', 'json_schema', 'prompted_json'] = 'tool' + default_output_mode: StructuredOutputMode = 'tool' """The default output mode to use for the model.""" + prompted_json_output_instructions: str = dedent( + """ + Always respond with a JSON object that's compatible with this schema: + + {schema} + + Don't include any text or Markdown fencing before or after. + """ + ) + """The instructions to use for prompted JSON output. The schema placeholder will be replaced with the JSON schema for the output.""" + @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: """Build a ModelProfile subclass instance from a ModelProfile instance.""" diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 57174e27e..0f7416062 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -8,7 +8,7 @@ from typing import Generic from pydantic import ValidationError -from typing_extensions import TypeVar, assert_type, deprecated, overload +from typing_extensions import TypeVar, deprecated, overload from . import _utils, exceptions, messages as _messages, models from ._output import ( @@ -18,9 +18,12 @@ OutputSchema, OutputValidator, OutputValidatorFunc, + PlainTextOutputSchema, PromptedJsonOutput, TextOutput, + TextOutputSchema, ToolOutput, + ToolOutputSchema, ) from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext @@ -92,7 +95,7 @@ async def _validate_response( ) -> OutputDataT: """Validate a structured result message.""" call = None - if output_tool_name is not None: + if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -103,12 +106,16 @@ async def _validate_response( result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) - else: + elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + else: + raise exceptions.UnexpectedModelBehavior( # pragma: no cover + 'Invalid response, unable to process text output' + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) @@ -131,11 +138,12 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" if isinstance(e, _messages.PartStartEvent): new_part = e.part - if isinstance(new_part, _messages.ToolCallPart): + if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema): for call, _ in output_schema.find_tool([new_part]): # pragma: no branch return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id) - elif output_schema.allow_text_output: # pragma: no branch - assert_type(e, _messages.PartStartEvent) + elif isinstance(new_part, _messages.TextPart) and isinstance( + output_schema, TextOutputSchema + ): # pragma: no branch return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) usage_checking_stream = _get_usage_checking_stream_response( @@ -326,7 +334,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ - if self._output_schema.allow_text_output != 'plain': + if not isinstance(self._output_schema, PlainTextOutputSchema): raise exceptions.UserError('stream_text() can only be used with text responses') if delta: @@ -405,7 +413,7 @@ async def validate_structured_output( ) -> OutputDataT: """Validate a structured result message.""" call = None - if self._output_tool_name is not None: + if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -416,12 +424,16 @@ async def validate_structured_output( result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) - else: + elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + else: + raise exceptions.UnexpectedModelBehavior( # pragma: no cover + 'Invalid response, unable to process text output' + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover diff --git a/tests/test_agent.py b/tests/test_agent.py index 751b5036e..28e541852 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -13,7 +13,15 @@ from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import JsonSchemaOutput, OutputSpec, PromptedJsonOutput, TextOutput, ToolOutput +from pydantic_ai._output import ( + JsonSchemaOutput, + OutputSpec, + PromptedJsonOutput, + TextOutput, + TextOutputSchema, + ToolOutput, + ToolOutputSchema, +) from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -355,7 +363,7 @@ def test_response_tuple(): m = TestModel() agent = Agent(m, output_type=tuple[str, str]) - assert agent._output_schema.allow_text_output == 'json' # pyright: ignore[reportPrivateUsage] + assert isinstance(agent._output_schema, ToolOutputSchema) # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') assert result.output == snapshot(('a', 'a')) @@ -430,7 +438,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: got_tool_call_name = ctx.tool_name return o - assert agent._output_schema.allow_text_output == 'plain' # pyright: ignore[reportPrivateUsage] + assert isinstance(agent._output_schema, TextOutputSchema) # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') assert isinstance(result.output, str) From e70d24905825b9de96e8adf97833d9c50b08f978 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 12 Jun 2025 14:11:29 +0000 Subject: [PATCH 23/90] Clean up some variable names --- pydantic_ai_slim/pydantic_ai/_output.py | 126 ++++++++++-------------- pydantic_ai_slim/pydantic_ai/_utils.py | 13 +++ 2 files changed, 67 insertions(+), 72 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 471bfa7d4..f2e54f6c9 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -2,7 +2,6 @@ import inspect import json -import re from abc import ABC, abstractmethod from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field @@ -116,7 +115,7 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): class ToolOutput(Generic[OutputDataT]): """Marker class to use tools for outputs, and customize the tool.""" - output_type: OutputTypeOrFunction[OutputDataT] + output: OutputTypeOrFunction[OutputDataT] name: str | None description: str | None max_retries: int | None @@ -131,7 +130,7 @@ def __init__( max_retries: int | None = None, strict: bool | None = None, ): - self.output_type = type_ + self.output = type_ self.name = name self.description = description self.max_retries = max_retries @@ -288,7 +287,7 @@ def build( if isinstance(output_spec, JsonSchemaOutput): return JsonSchemaOutputSchema( - text_processor=cls._build_text_processor( + cls._build_processor( output_spec.outputs, name=output_spec.name, description=output_spec.description, @@ -298,9 +297,7 @@ def build( if isinstance(output_spec, PromptedJsonOutput): return PromptedJsonOutputSchema( - text_processor=cls._build_text_processor( - output_spec.outputs, name=output_spec.name, description=output_spec.description - ), + cls._build_processor(output_spec.outputs, name=output_spec.name, description=output_spec.description), ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] @@ -328,18 +325,16 @@ def build( text_output_schema = PlainTextOutputProcessor(text_output.output_function) if len(tools) == 0: - return PlainTextOutputSchema(text_processor=text_output_schema) + return PlainTextOutputSchema(text_output_schema) else: - return ToolOrTextOutputSchema(text_processor=text_output_schema, tools=tools) + return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools) if len(tool_outputs) > 0: - return ToolOutputSchema(tools=tools) + return ToolOutputSchema(tools) if len(other_outputs) > 0: schema = OutputSchemaWithoutMode( - text_processor=cls._build_text_processor( - other_outputs, name=name, description=description, strict=strict - ), + processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), tools=tools, ) if default_mode: @@ -367,18 +362,17 @@ def _build_tools( description = None strict = None if isinstance(output, ToolOutput): - output_type = output.output_type # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads name = output.name description = output.description strict = output.strict - else: - output_type = output + + output = output.output if name is None: name = default_name if multiple: - name += f'_{output_type.__name__}' + name += f'_{output.__name__}' i = 1 original_name = name @@ -390,13 +384,13 @@ def _build_tools( if strict is None: strict = default_strict - parameters_schema = ObjectOutputProcessor(output=output_type, description=description, strict=strict) - tools[name] = OutputTool(name=name, parameters_schema=parameters_schema, multiple=multiple) + processor = ObjectOutputProcessor(output=output, description=description, strict=strict) + tools[name] = OutputTool(name=name, processor=processor, multiple=multiple) return tools @staticmethod - def _build_text_processor( + def _build_processor( outputs: Sequence[OutputTypeOrFunction[OutputDataT]], name: str | None = None, description: str | None = None, @@ -419,15 +413,15 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDa @dataclass(init=False) class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): - text_processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] + processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) def __init__( self, - text_processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], + processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], tools: dict[str, OutputTool[OutputDataT]], ): - self.text_processor = text_processor + self.processor = processor self._tools = tools @property @@ -437,11 +431,11 @@ def mode(self) -> None: def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'json_schema': return JsonSchemaOutputSchema( - text_processor=self.text_processor, + self.processor, ) elif mode == 'prompted_json': return PromptedJsonOutputSchema( - text_processor=self.text_processor, + self.processor, ) elif mode == 'tool': return ToolOutputSchema(tools=self.tools) @@ -474,7 +468,7 @@ async def process( @dataclass class PlainTextOutputSchema(TextOutputSchema[OutputDataT]): - text_processor: PlainTextOutputProcessor[OutputDataT] | None = None + processor: PlainTextOutputProcessor[OutputDataT] | None = None @property def mode(self) -> OutputMode: @@ -502,21 +496,21 @@ async def process( Returns: Either the validated output data (left) or a retry message (right). """ - if self.text_processor is None: + if self.processor is None: return cast(OutputDataT, text) - return await self.text_processor.process( + return await self.processor.process( text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @dataclass class JsonTextOutputSchema(TextOutputSchema[OutputDataT], ABC): - text_processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] + processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] @property def object_def(self) -> OutputObjectDefinition: - return self.text_processor.object_def + return self.processor.object_def async def process( self, @@ -536,21 +530,9 @@ async def process( Returns: Either the validated output data (left) or a retry message (right). """ + text = _utils.strip_markdown_fences(text) - def strip_markdown_fences(text: str) -> str: - if text.startswith('{'): - return text - - regex = r'```(?:\w+)?\n(\{.*\})\n```' - match = re.search(regex, text, re.DOTALL) - if match: - return match.group(1) - - return text - - text = strip_markdown_fences(text) - - return await self.text_processor.process( + return await self.processor.process( text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -638,10 +620,10 @@ def find_tool( class ToolOrTextOutputSchema(PlainTextOutputSchema[OutputDataT], ToolOutputSchema[OutputDataT]): def __init__( self, - text_processor: PlainTextOutputProcessor[OutputDataT] | None, + processor: PlainTextOutputProcessor[OutputDataT] | None, tools: dict[str, OutputTool[OutputDataT]], ): - self.text_processor = text_processor + self.processor = processor self._tools = tools @property @@ -795,8 +777,8 @@ class UnionOutputModel: @dataclass(init=False) class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]): object_def: OutputObjectDefinition - _union_schema: ObjectOutputProcessor[UnionOutputModel] - _object_schemas: dict[str, ObjectOutputProcessor[OutputDataT]] + _union_processor: ObjectOutputProcessor[UnionOutputModel] + _processors: dict[str, ObjectOutputProcessor[OutputDataT]] def __init__( self, @@ -806,35 +788,35 @@ def __init__( description: str | None = None, strict: bool | None = None, ): - self._union_schema = ObjectOutputProcessor(output=UnionOutputModel) + self._union_processor = ObjectOutputProcessor(output=UnionOutputModel) json_schemas: list[ObjectJsonSchema] = [] - self._object_schemas = {} + self._processors = {} for output in outputs: - object_schema = ObjectOutputProcessor(output=output, strict=strict) - object_def = object_schema.object_def + processor = ObjectOutputProcessor(output=output, strict=strict) + object_def = processor.object_def object_key = object_def.name or output.__name__ i = 1 original_key = object_key - while object_key in self._object_schemas: + while object_key in self._processors: i += 1 object_key = f'{original_key}_{i}' - self._object_schemas[object_key] = object_schema + self._processors[object_key] = processor json_schema = object_def.json_schema - if object_name := object_def.name: - json_schema['title'] = object_name - if object_description := object_def.description: - json_schema['description'] = object_description + if object_def.name: + json_schema['title'] = object_def.name + if object_def.description: + json_schema['description'] = object_def.description json_schemas.append(json_schema) json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas) discriminated_json_schemas: list[ObjectJsonSchema] = [] - for object_key, json_schema in zip(self._object_schemas.keys(), json_schemas): + for object_key, json_schema in zip(self._processors.keys(), json_schemas): title = json_schema.pop('title', None) description = json_schema.pop('description', None) @@ -884,7 +866,7 @@ async def process( allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: - union_object = await self._union_schema.process( + union_object = await self._union_processor.process( data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -892,7 +874,7 @@ async def process( kind = result.kind data = result.data try: - object_schema = self._object_schemas[kind] + processor = self._processors[kind] except KeyError as e: if wrap_validation_errors: m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}') @@ -900,7 +882,7 @@ async def process( else: raise # pragma: lax no cover - return await object_schema.process( + return await processor.process( data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -954,25 +936,25 @@ async def process( @dataclass(init=False) class OutputTool(Generic[OutputDataT]): - parameters_schema: ObjectOutputProcessor[OutputDataT] + processor: ObjectOutputProcessor[OutputDataT] tool_def: ToolDefinition - def __init__(self, *, name: str, parameters_schema: ObjectOutputProcessor[OutputDataT], multiple: bool): - self.parameters_schema = parameters_schema - definition = parameters_schema.object_def + def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool): + self.processor = processor + object_def = processor.object_def - description = definition.description + description = object_def.description if not description: description = DEFAULT_OUTPUT_TOOL_DESCRIPTION if multiple: - description = f'{definition.name}: {description}' + description = f'{object_def.name}: {description}' self.tool_def = ToolDefinition( name=name, description=description, - parameters_json_schema=definition.json_schema, - strict=definition.strict, - outer_typed_dict_key=parameters_schema.outer_typed_dict_key, + parameters_json_schema=object_def.json_schema, + strict=object_def.strict, + outer_typed_dict_key=processor.outer_typed_dict_key, ) async def process( @@ -994,7 +976,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - output = await self.parameters_schema.process( + output = await self.processor.process( tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False ) except ValidationError as e: diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index a280115f8..8e0d97165 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import asyncio +import re import time import uuid from collections.abc import AsyncIterable, AsyncIterator, Iterator @@ -381,3 +382,15 @@ def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str rewritten_schemas.append(schema) return rewritten_schemas, all_defs + + +def strip_markdown_fences(text: str) -> str: + if text.startswith('{'): + return text + + regex = r'```(?:\w+)?\n(\{.*\})\n```' + match = re.search(regex, text, re.DOTALL) + if match: + return match.group(1) + + return text From 4592b0b4dcfe8540159d0463e06084ba6c0b2a71 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 12 Jun 2025 14:56:55 +0000 Subject: [PATCH 24/90] Improve test coverage --- pydantic_ai_slim/pydantic_ai/_output.py | 27 +++-- pydantic_ai_slim/pydantic_ai/_utils.py | 2 +- tests/test_utils.py | 133 +++++++++++++++++++++++- 3 files changed, 146 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index f2e54f6c9..37787bc9b 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -426,7 +426,7 @@ def __init__( @property def mode(self) -> None: - return None + return None # pragma: no cover def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'json_schema': @@ -444,7 +444,7 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDa def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: """Whether the mode is supported by the model.""" - return False + return False # pragma: no cover @property def tools(self) -> dict[str, OutputTool[OutputDataT]]: @@ -806,7 +806,7 @@ def __init__( self._processors[object_key] = processor json_schema = object_def.json_schema - if object_def.name: + if object_def.name: # pragma: no branch json_schema['title'] = object_def.name if object_def.description: json_schema['description'] = object_def.description @@ -832,7 +832,7 @@ def __init__( 'required': ['kind', 'data'], 'additionalProperties': False, } - if title: + if title: # pragma: no branch discriminated_json_schema['title'] = title if description: discriminated_json_schema['description'] = description @@ -875,12 +875,12 @@ async def process( data = result.data try: processor = self._processors[kind] - except KeyError as e: + except KeyError as e: # pragma: no cover if wrap_validation_errors: m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}') raise ToolRetryError(m) from e else: - raise # pragma: lax no cover + raise return await processor.process( data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors @@ -896,20 +896,19 @@ def __init__( self, output_function: TextOutputFunction[OutputDataT], ): - if inspect.isfunction(output_function) or inspect.ismethod(output_function): - self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema) + self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema) - arguments_schema = self._function_schema.json_schema.get('properties', {}) - argument_name = next(iter(arguments_schema.keys()), None) - if argument_name and arguments_schema.get(argument_name, {}).get('type') == 'string': - self._str_argument_name = argument_name - return + arguments_schema = self._function_schema.json_schema.get('properties', {}) + argument_name = next(iter(arguments_schema.keys()), None) + if argument_name and arguments_schema.get(argument_name, {}).get('type') == 'string': + self._str_argument_name = argument_name + return raise UserError('TextOutput must take a function taking a `str`') @property def object_def(self) -> None: - return None + return None # pragma: no cover async def process( self, diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 8e0d97165..d4e770601 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -309,7 +309,7 @@ def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, s """Update $refs in a schema to use the new names from name_mapping.""" if '$ref' in s: ref = s['$ref'] - if ref.startswith('#/$defs/'): + if ref.startswith('#/$defs/'): # pragma: no branch original_name = ref[8:] # Remove '#/$defs/' new_name = name_mapping.get(original_name, original_name) s['$ref'] = f'#/$defs/{new_name}' diff --git a/tests/test_utils.py b/tests/test_utils.py index fdd042b46..afd0b7fa9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,6 +17,7 @@ group_by_temporal, merge_json_schema_defs, run_in_executor, + strip_markdown_fences, ) from .models.mock_async_stream import MockAsyncStream @@ -262,7 +263,67 @@ def test_merge_json_schema_defs(): 'type': 'object', } - schemas = [foo_bar_schema, foo_bar_baz_schema, no_title_schema, another_no_title_schema] + # Schema with nested properties, array items, prefixItems, and anyOf/oneOf + complex_schema = { + '$defs': { + 'Nested': { + 'description': 'A nested type', + 'properties': {'nested': {'type': 'string'}}, + 'required': ['nested'], + 'title': 'Nested', + 'type': 'object', + }, + 'ArrayItem': { + 'description': 'An array item type', + 'properties': {'item': {'type': 'string'}}, + 'required': ['item'], + 'title': 'ArrayItem', + 'type': 'object', + }, + 'UnionType': { + 'description': 'A union type', + 'properties': {'union': {'type': 'string'}}, + 'required': ['union'], + 'title': 'UnionType', + 'type': 'object', + }, + }, + 'properties': { + 'nested_props': { + 'type': 'object', + 'properties': { + 'deep_nested': {'$ref': '#/$defs/Nested'}, + }, + }, + 'array_with_items': { + 'type': 'array', + 'items': {'$ref': '#/$defs/ArrayItem'}, + }, + 'array_with_prefix': { + 'type': 'array', + 'prefixItems': [ + {'$ref': '#/$defs/ArrayItem'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_anyOf': { + 'anyOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_oneOf': { + 'oneOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/ArrayItem'}, + ], + }, + }, + 'type': 'object', + 'title': 'ComplexSchema', + } + + schemas = [foo_bar_schema, foo_bar_baz_schema, no_title_schema, another_no_title_schema, complex_schema] rewritten_schemas, all_defs = merge_json_schema_defs(schemas) assert all_defs == snapshot( { @@ -322,6 +383,27 @@ def test_merge_json_schema_defs(): 'title': 'Bar', 'type': 'object', }, + 'Nested': { + 'description': 'A nested type', + 'properties': {'nested': {'type': 'string'}}, + 'required': ['nested'], + 'title': 'Nested', + 'type': 'object', + }, + 'ArrayItem': { + 'description': 'An array item type', + 'properties': {'item': {'type': 'string'}}, + 'required': ['item'], + 'title': 'ArrayItem', + 'type': 'object', + }, + 'UnionType': { + 'description': 'A union type', + 'properties': {'union': {'type': 'string'}}, + 'required': ['union'], + 'title': 'UnionType', + 'type': 'object', + }, } ) assert rewritten_schemas == snapshot( @@ -352,5 +434,54 @@ def test_merge_json_schema_defs(): 'required': ['foo', 'bar'], 'type': 'object', }, + { + 'properties': { + 'nested_props': { + 'type': 'object', + 'properties': { + 'deep_nested': {'$ref': '#/$defs/Nested'}, + }, + }, + 'array_with_items': { + 'type': 'array', + 'items': {'$ref': '#/$defs/ArrayItem'}, + }, + 'array_with_prefix': { + 'type': 'array', + 'prefixItems': [ + {'$ref': '#/$defs/ArrayItem'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_anyOf': { + 'anyOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_oneOf': { + 'oneOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/ArrayItem'}, + ], + }, + }, + 'type': 'object', + 'title': 'ComplexSchema', + }, ] ) + + +def test_strip_markdown_fences(): + assert strip_markdown_fences('{"foo": "bar"}') == '{"foo": "bar"}' + assert strip_markdown_fences('```json\n{"foo": "bar"}\n```') == '{"foo": "bar"}' + assert ( + strip_markdown_fences('{"foo": "```json\\n{"foo": "bar"}\\n```"}') + == '{"foo": "```json\\n{"foo": "bar"}\\n```"}' + ) + assert ( + strip_markdown_fences('Here is some beautiful JSON:\n\n```\n{"foo": "bar"}\n``` Nice right?') + == '{"foo": "bar"}' + ) + assert strip_markdown_fences('No JSON to be found') == 'No JSON to be found' From f57d078160de4c9fe8bcd1a2cc1850cc4a2fde8f Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 19:23:05 +0000 Subject: [PATCH 25/90] Combine JsonSchemaOutput and PromptedJsonOutput into StructuredTextOutput --- pydantic_ai_slim/pydantic_ai/__init__.py | 5 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 19 +- pydantic_ai_slim/pydantic_ai/_output.py | 167 +++---- pydantic_ai_slim/pydantic_ai/agent.py | 20 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 22 +- pydantic_ai_slim/pydantic_ai/models/google.py | 21 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 44 +- .../pydantic_ai/profiles/__init__.py | 25 +- .../pydantic_ai/profiles/google.py | 6 +- .../pydantic_ai/profiles/openai.py | 14 +- pydantic_ai_slim/pydantic_ai/result.py | 6 +- .../test_anthropic_prompted_json_output.yaml | 161 ------- ...thropic_prompted_json_output_multiple.yaml | 66 --- .../test_gemini_json_schema_output.yaml | 79 ---- ...st_gemini_json_schema_output_multiple.yaml | 120 ----- .../test_gemini_prompted_json_output.yaml | 74 --- ..._gemini_prompted_json_output_multiple.yaml | 73 --- ...emini_prompted_json_output_with_tools.yaml | 157 ------- .../test_google_json_schema_output.yaml | 86 ---- ...st_google_json_schema_output_multiple.yaml | 138 ------ .../test_google_prompted_json_output.yaml | 78 --- ..._google_prompted_json_output_multiple.yaml | 77 --- ...oogle_prompted_json_output_with_tools.yaml | 164 ------- .../test_openai_json_schema_output.yaml | 223 --------- ...st_openai_json_schema_output_multiple.yaml | 293 ------------ .../test_openai_prompted_json_output.yaml | 209 --------- ..._openai_prompted_json_output_multiple.yaml | 209 --------- .../test_json_schema_output.yaml | 288 ------------ .../test_json_schema_output_multiple.yaml | 444 ------------------ .../test_prompted_json_output.yaml | 248 ---------- .../test_prompted_json_output_multiple.yaml | 248 ---------- tests/models/test_anthropic.py | 10 +- tests/models/test_gemini.py | 30 +- tests/models/test_google.py | 32 +- tests/models/test_openai.py | 22 +- tests/models/test_openai_responses.py | 30 +- tests/test_agent.py | 27 +- tests/test_streaming.py | 6 +- 38 files changed, 254 insertions(+), 3687 deletions(-) delete mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml delete mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_json_schema_output.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_prompted_json_output.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 43d985dc4..5e3c9aaa6 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .result import JsonSchemaOutput, PromptedJsonOutput, ToolOutput +from .result import StructuredTextOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -43,8 +43,7 @@ 'RunContext', # result 'ToolOutput', - 'JsonSchemaOutput', - 'PromptedJsonOutput', + 'StructuredTextOutput', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 82cf16df7..ad7fb0b34 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -270,12 +270,25 @@ async def add_mcp_server_tools(server: MCPServer) -> None: function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] output_schema = ctx.deps.output_schema + model_profile = ctx.deps.model.profile + + output_tools = [] + output_object = None + if isinstance(output_schema, _output.ToolOutputSchema): + output_tools = output_schema.tool_defs() + elif isinstance(output_schema, _output.StructuredTextOutputSchema): + if not output_schema.use_instructions(model_profile): + output_object = output_schema.object_def + + # Both ToolOrTextOutputSchema and StructuredTextOutputSchema inherit from TextOutputSchema + allow_text_output = isinstance(output_schema, _output.TextOutputSchema) + return models.ModelRequestParameters( function_tools=function_tool_defs, output_mode=output_schema.mode, - output_object=output_schema.object_def if isinstance(output_schema, _output.JsonTextOutputSchema) else None, - output_tools=output_schema.tool_defs() if isinstance(output_schema, _output.ToolOutputSchema) else [], - allow_text_output=isinstance(output_schema, _output.TextOutputSchema), + output_tools=output_tools, + output_object=output_object, + allow_text_output=allow_text_output, ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 8b0723e86..d71707391 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field -from typing import Any, Callable, Generic, Literal, Union, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator @@ -17,6 +17,9 @@ from .exceptions import ModelRetry, UserError from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition +if TYPE_CHECKING: + from .profiles import ModelProfile + T = TypeVar('T') """An invariant TypeVar.""" OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) @@ -145,10 +148,18 @@ class TextOutput(Generic[OutputDataT]): @dataclass(init=False) -class JsonSchemaOutput(Generic[OutputDataT]): - """Marker class to use JSON schema output for outputs.""" +class StructuredTextOutput(Generic[OutputDataT]): + """Marker class to use structured text output for outputs.""" outputs: Sequence[OutputTypeOrFunction[OutputDataT]] + instructions: bool | str | None + """Whether to use the model's built-in functionality for structured output matching a JSON schema, or to pass the JSON schema to the model as instructions. + + If `None`, we'll use the model's built-in functionality if it's supported, and otherwise pass the JSON schema to the model as instructions. + If `True`, we'll pass the JSON schema to the model using the instructions template specified on the model's profile. + If `False`, we'll use the model's built-in functionality and raise an error if it's not supported. + If `str`, we'll pass the JSON schema to the model using the specified instructions template. + """ name: str | None description: str | None strict: bool | None @@ -160,30 +171,13 @@ def __init__( name: str | None = None, description: str | None = None, strict: bool | None = True, + instructions: bool | str | None = None, ): self.outputs = flatten_output_spec(type_) self.name = name self.description = description self.strict = strict - - -class PromptedJsonOutput(Generic[OutputDataT]): - """Marker class to use prompted JSON mode for outputs.""" - - outputs: Sequence[OutputTypeOrFunction[OutputDataT]] - name: str | None - description: str | None - - def __init__( - self, - type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], - *, - name: str | None = None, - description: str | None = None, - ): - self.outputs = flatten_output_spec(type_) - self.name = name - self.description = description + self.instructions = instructions T_co = TypeVar('T_co', covariant=True) @@ -197,9 +191,8 @@ def __init__( OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co], + StructuredTextOutput[T_co], Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], - JsonSchemaOutput[T_co], - PromptedJsonOutput[T_co], ], type_params=(T_co,), ) @@ -213,30 +206,17 @@ def __init__( type_params=(T_co,), ) - -OutputMode = Literal['text', 'tool', 'json_schema', 'prompted_json', 'tool_or_text'] +OutputMode = Literal['text', 'tool', 'structured_text', 'tool_or_text'] """All output modes.""" -SupportableOutputMode = Literal['tool', 'json_schema'] -"""Output modes that require specific support by a model (class). Used by ModelProfile.output_modes""" -StructuredOutputMode = Literal['tool', 'json_schema', 'prompted_json'] -"""Output modes that can be used for any structured output. Used by ModelProfile.default_output_mode""" +StructuredOutputMode = Literal['tool', 'structured_text'] +"""Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode""" class BaseOutputSchema(ABC, Generic[OutputDataT]): - @property - @abstractmethod - def mode(self) -> OutputMode | None: - raise NotImplementedError() - @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: raise NotImplementedError() - @abstractmethod - def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: - """Whether the mode is supported by the model.""" - raise NotImplementedError() - @property def tools(self) -> dict[str, OutputTool[OutputDataT]]: """Get the tools for this output schema.""" @@ -269,7 +249,7 @@ def build( name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> OutputSchemaWithoutMode[OutputDataT]: ... + ) -> BaseOutputSchema[OutputDataT]: ... @classmethod def build( @@ -285,19 +265,15 @@ def build( if output_spec is str: return PlainTextOutputSchema() - if isinstance(output_spec, JsonSchemaOutput): - return JsonSchemaOutputSchema( + if isinstance(output_spec, StructuredTextOutput): + return StructuredTextOutputSchema( cls._build_processor( output_spec.outputs, name=output_spec.name, description=output_spec.description, strict=output_spec.strict, ), - ) - - if isinstance(output_spec, PromptedJsonOutput): - return PromptedJsonOutputSchema( - cls._build_processor(output_spec.outputs, name=output_spec.name, description=output_spec.description), + instructions=output_spec.instructions, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] @@ -407,6 +383,11 @@ def _build_processor( def mode(self) -> OutputMode: raise NotImplementedError() + @abstractmethod + def raise_if_unsupported(self, profile: ModelProfile) -> None: + """Raise an error if the mode is not supported by the model.""" + raise NotImplementedError() + def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: return self @@ -424,28 +405,14 @@ def __init__( self.processor = processor self._tools = tools - @property - def mode(self) -> None: - return None # pragma: no cover - def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: - if mode == 'json_schema': - return JsonSchemaOutputSchema( - self.processor, - ) - elif mode == 'prompted_json': - return PromptedJsonOutputSchema( - self.processor, - ) + if mode == 'structured_text': + return StructuredTextOutputSchema(self.processor) elif mode == 'tool': - return ToolOutputSchema(tools=self.tools) + return ToolOutputSchema(self.tools) else: assert_never(mode) - def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: - """Whether the mode is supported by the model.""" - return False # pragma: no cover - @property def tools(self) -> dict[str, OutputTool[OutputDataT]]: """Get the tools for this output schema.""" @@ -474,9 +441,9 @@ class PlainTextOutputSchema(TextOutputSchema[OutputDataT]): def mode(self) -> OutputMode: return 'text' - def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: - """Whether the mode is supported by the model.""" - return True + def raise_if_unsupported(self, profile: ModelProfile) -> None: + """Raise an error if the mode is not supported by the model.""" + pass async def process( self, @@ -504,9 +471,18 @@ async def process( ) -@dataclass -class JsonTextOutputSchema(TextOutputSchema[OutputDataT], ABC): +@dataclass(init=False) +class StructuredTextOutputSchema(TextOutputSchema[OutputDataT]): processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] + _instructions: bool | str | None = None + + def __init__( + self, + processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], + instructions: bool | str | None = None, + ): + self.processor = processor + self._instructions = instructions @property def object_def(self) -> OutputObjectDefinition: @@ -536,28 +512,31 @@ async def process( text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) - -class JsonSchemaOutputSchema(JsonTextOutputSchema[OutputDataT]): @property def mode(self) -> OutputMode: - return 'json_schema' - - def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: - """Whether the mode is supported by the model.""" - return 'json_schema' in supported_modes - + return 'structured_text' + + def raise_if_unsupported(self, profile: ModelProfile) -> None: + """Raise an error if the mode is not supported by the model.""" + if self._instructions is False and not profile.supports_json_schema_response_format: + raise UserError('Structured output without using instructions is not supported by the model.') + + def use_instructions(self, profile: ModelProfile) -> bool: + if isinstance(self._instructions, bool): + return self._instructions + elif isinstance(self._instructions, str): + return True + else: + return not profile.supports_json_schema_response_format -class PromptedJsonOutputSchema(JsonTextOutputSchema[OutputDataT]): - @property - def mode(self) -> OutputMode: - return 'prompted_json' + def instructions(self, template: str) -> str: + """Get instructions to tell model to output JSON matching the schema.""" + if isinstance(self._instructions, str): + template = self._instructions - def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: - """Whether the mode is supported by the model.""" - return True + if '{schema}' not in template: + raise UserError("Structured output instructions template must contain a '{schema}' placeholder.") - def instructions(self, template: str) -> str: - """Get instructions for model to output manual JSON matching the schema.""" object_def = self.object_def schema = object_def.json_schema.copy() if object_def.name: @@ -579,9 +558,10 @@ def __init__(self, tools: dict[str, OutputTool[OutputDataT]]): def mode(self) -> OutputMode: return 'tool' - def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: - """Whether the mode is supported by the model.""" - return 'tool' in supported_modes + def raise_if_unsupported(self, profile: ModelProfile) -> None: + """Raise an error if the mode is not supported by the model.""" + if not profile.supports_tools: + raise UserError('Output tools are not supported by the model.') @property def tools(self) -> dict[str, OutputTool[OutputDataT]]: @@ -630,9 +610,10 @@ def __init__( def mode(self) -> OutputMode: return 'tool_or_text' - def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: - """Whether the mode is supported by the model.""" - return 'tool' in supported_modes + def raise_if_unsupported(self, profile: ModelProfile) -> None: + """Raise an error if the mode is not supported by the model.""" + if not profile.supports_tools: + raise UserError('Output tools are not supported by the model.') @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b48650adf..7cc45e4cb 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -318,7 +318,9 @@ def __init__( warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning) output_retries = result_retries - default_output_mode = self.model.profile.default_output_mode if isinstance(self.model, models.Model) else None + default_output_mode = ( + self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None + ) self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, default_mode=default_output_mode, @@ -676,9 +678,11 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: *[await func.run(run_context) for func in self._instructions_functions], ] - if isinstance(output_schema, _output.PromptedJsonOutputSchema): - template = model_used.profile.prompted_json_output_instructions - instructions = output_schema.instructions(template) + model_profile = model_used.profile + if isinstance(output_schema, _output.StructuredTextOutputSchema) and output_schema.use_instructions( + model_profile + ): + instructions = output_schema.instructions(model_profile.structured_output_instructions_template) parts.append(instructions) parts = [p for p in parts if p] @@ -1651,14 +1655,12 @@ def _prepare_output_schema( output_type, name=self._deprecated_result_tool_name, description=self._deprecated_result_tool_description, - default_mode=model_profile.default_output_mode, + default_mode=model_profile.default_structured_output_mode, ) else: - schema = self._output_schema.with_default_mode(model_profile.default_output_mode) + schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode) - if not schema.is_supported(model_profile.output_modes): - modes = ', '.join(f"'{m}'" for m in model_profile.output_modes) - raise exceptions.UserError(f"Output mode '{schema.mode}' is not among supported modes: {modes}") + schema.raise_if_unsupported(model_profile) return schema # pyright: ignore[reportReturnType] diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index f8ccfc22f..c4c28b37b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -218,19 +218,15 @@ async def _make_request( request_data['toolConfig'] = tool_config generation_config = _settings_to_generation_config(model_settings) - - output_mode = model_request_parameters.output_mode - if output_mode == 'json_schema': - generation_config['response_mime_type'] = 'application/json' - - output_object = model_request_parameters.output_object - assert output_object is not None - generation_config['response_schema'] = self._map_response_schema(output_object) - - if tools: - raise UserError('Google does not support JSON schema output and tools at the same time.') - elif output_mode == 'prompted_json' and not tools: - generation_config['response_mime_type'] = 'application/json' + if model_request_parameters.output_mode == 'structured_text': + if output_object := model_request_parameters.output_object: + if tools: + raise UserError('Google does not support JSON schema output and tools at the same time.') + + generation_config['response_mime_type'] = 'application/json' + generation_config['response_schema'] = self._map_response_schema(output_object) + elif not tools: + generation_config['response_mime_type'] = 'application/json' if generation_config: request_data['generationConfig'] = generation_config diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 16bf5e5f9..5eeedd703 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -250,20 +250,17 @@ async def _generate_content( ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]: tools = self._get_tools(model_request_parameters) - output_mode = model_request_parameters.output_mode response_mime_type = None response_schema = None - if output_mode == 'json_schema': - response_mime_type = 'application/json' - - output_object = model_request_parameters.output_object - assert output_object is not None - response_schema = self._map_response_schema(output_object) - - if tools: - raise UserError('Google does not support JSON schema output and tools at the same time/') - elif output_mode == 'prompted_json' and not tools: - response_mime_type = 'application/json' + if model_request_parameters.output_mode == 'structured_text': + if output_object := model_request_parameters.output_object: + if tools: + raise UserError('Google does not support JSON schema output and tools at the same time.') + + response_mime_type = 'application/json' + response_schema = self._map_response_schema(output_object) + elif not tools: + response_mime_type = 'application/json' tool_config = self._get_tool_config(model_request_parameters, tools) system_instruction, contents = await self._map_messages(messages) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 244b2aeb9..eace317b9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -281,16 +281,11 @@ async def _completions_create( openai_messages = await self._map_messages(messages) response_format: chat.completion_create_params.ResponseFormat | None = None - output_mode = model_request_parameters.output_mode - if output_mode == 'json_schema': - output_object = model_request_parameters.output_object - assert output_object is not None - response_format = self._map_json_schema(output_object) - elif ( - output_mode == 'prompted_json' - and OpenAIModelProfile.from_profile(self.profile).openai_supports_json_object_response_format - ): - response_format = {'type': 'json_object'} + if model_request_parameters.output_mode == 'structured_text': + if output_object := model_request_parameters.output_object: + response_format = self._map_json_schema(output_object) + elif self.profile.supports_json_object_response_format: + response_format = {'type': 'json_object'} sampling_settings = ( model_settings @@ -703,23 +698,18 @@ async def _responses_create( reasoning = self._get_reasoning(model_settings) text: responses.ResponseTextConfigParam | None = None - output_mode = model_request_parameters.output_mode - if output_mode == 'json_schema': - output_object = model_request_parameters.output_object - assert output_object is not None - text = {'format': self._map_json_schema(output_object)} - elif ( - output_mode == 'prompted_json' - and OpenAIModelProfile.from_profile(self.profile).openai_supports_json_object_response_format - ): - text = {'format': {'type': 'json_object'}} - - # Without this trick, we'd hit this error: - # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. - # Apparently they're only checking input messages for "JSON", not instructions. - assert isinstance(instructions, str) - openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) - instructions = NOT_GIVEN + if model_request_parameters.output_mode == 'structured_text': + if output_object := model_request_parameters.output_object: + text = {'format': self._map_json_schema(output_object)} + elif self.profile.supports_json_object_response_format: + text = {'format': {'type': 'json_object'}} + + # Without this trick, we'd hit this error: + # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. + # Apparently they're only checking input messages for "JSON", not instructions. + assert isinstance(instructions, str) + openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) + instructions = NOT_GIVEN sampling_settings = ( model_settings diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 5f7119bff..7b7f9b9d0 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -1,28 +1,29 @@ from __future__ import annotations as _annotations -from dataclasses import dataclass, field, fields, replace +from dataclasses import dataclass, fields, replace from textwrap import dedent from typing import Callable, Union from typing_extensions import Self -from pydantic_ai._output import StructuredOutputMode, SupportableOutputMode +from pydantic_ai._output import StructuredOutputMode from ._json_schema import JsonSchemaTransformer -@dataclass +@dataclass(kw_only=True) class ModelProfile: """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used.""" - json_schema_transformer: type[JsonSchemaTransformer] | None = None - """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" - output_modes: set[SupportableOutputMode] = field(default_factory=lambda: {'tool'}) - """The output modes supported by the model. Essentially all models support `tool` mode, but some also support `json_schema` mode, which needs to be specifically implemented on the model class.""" - default_output_mode: StructuredOutputMode = 'tool' - """The default output mode to use for the model.""" - - prompted_json_output_instructions: str = dedent( + supports_tools: bool = True + """Whether the model supports tools.""" + supports_json_schema_response_format: bool = False + """Whether the model supports the JSON schema response format.""" + supports_json_object_response_format: bool = False + """Whether the model supports the JSON object response format.""" + default_structured_output_mode: StructuredOutputMode = 'tool' + """The default structured output mode to use for the model.""" + structured_output_instructions_template: str = dedent( """ Always respond with a JSON object that's compatible with this schema: @@ -32,6 +33,8 @@ class ModelProfile: """ ) """The instructions to use for prompted JSON output. The schema placeholder will be replaced with the JSON schema for the output.""" + json_schema_transformer: type[JsonSchemaTransformer] | None = None + """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index a0cdc61fc..b151a5997 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -10,7 +10,11 @@ def google_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Google model.""" - return ModelProfile(json_schema_transformer=GoogleJsonSchemaTransformer, output_modes={'tool', 'json_schema'}) + return ModelProfile( + json_schema_transformer=GoogleJsonSchemaTransformer, + supports_json_schema_response_format=True, + supports_json_object_response_format=True, + ) class GoogleJsonSchemaTransformer(JsonSchemaTransformer): diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index 83e82f9ba..78be9814a 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -21,21 +21,17 @@ class OpenAIModelProfile(ModelProfile): openai_supports_sampling_settings: bool = True """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models.""" - openai_supports_json_object_response_format: bool = True - """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support the `json_object` `response_format`. - Note that if a model does not support the `json_schema` `response_format`, that value should be removed from `ModelProfile.output_modes`. - """ - def openai_model_profile(model_name: str) -> ModelProfile: """Get the model profile for an OpenAI model.""" is_reasoning_model = model_name.startswith('o') - # `json_schema` output_mode is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later, - # but we leave it in here for all models because the `default_output_mode` is `'tool'`, so `json_schema` is only used - # when the user specifically uses the JsonSchemaOutput marker, so an error from the API is acceptable. + # The JSON schema response format is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later. + # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `structured_text` is only used + # when the user specifically uses the StructuredTextOutput marker, so an error from the API is acceptable. return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, - output_modes={'tool', 'json_schema'}, + supports_json_schema_response_format=True, + supports_json_object_response_format=True, openai_supports_sampling_settings=not is_reasoning_model, ) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 0f7416062..e8131e247 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -12,14 +12,13 @@ from . import _utils, exceptions, messages as _messages, models from ._output import ( - JsonSchemaOutput, OutputDataT, OutputDataT_inv, OutputSchema, OutputValidator, OutputValidatorFunc, PlainTextOutputSchema, - PromptedJsonOutput, + StructuredTextOutput, TextOutput, TextOutputSchema, ToolOutput, @@ -34,8 +33,7 @@ 'OutputDataT_inv', 'ToolOutput', 'TextOutput', - 'JsonSchemaOutput', - 'PromptedJsonOutput', + 'StructuredTextOutput', 'OutputValidatorFunc', ) diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml deleted file mode 100644 index e88afebdf..000000000 --- a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml +++ /dev/null @@ -1,161 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '740' - content-type: - - application/json - host: - - api.anthropic.com - method: POST - parsed_body: - max_tokens: 1024 - messages: - - content: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - type: text - role: user - model: claude-3-5-sonnet-latest - stream: false - system: |+ - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - - tool_choice: - type: auto - tools: - - description: '' - input_schema: - additionalProperties: false - properties: {} - type: object - name: get_user_country - uri: https://api.anthropic.com/v1/messages?beta=true - response: - headers: - connection: - - keep-alive - content-length: - - '397' - content-type: - - application/json - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - content: - - id: toolu_017UryVwtsKsjonhFV3cgV3X - input: {} - name: get_user_country - type: tool_use - id: msg_014CpBKzioMqUyLWrMihpvsz - model: claude-3-5-sonnet-20241022 - role: assistant - stop_reason: tool_use - stop_sequence: null - type: message - usage: - cache_creation_input_tokens: 0 - cache_read_input_tokens: 0 - input_tokens: 459 - output_tokens: 38 - service_tier: standard - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1002' - content-type: - - application/json - host: - - api.anthropic.com - method: POST - parsed_body: - max_tokens: 1024 - messages: - - content: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - type: text - role: user - - content: - - id: toolu_017UryVwtsKsjonhFV3cgV3X - input: {} - name: get_user_country - type: tool_use - role: assistant - - content: - - content: Mexico - is_error: false - tool_use_id: toolu_017UryVwtsKsjonhFV3cgV3X - type: tool_result - role: user - model: claude-3-5-sonnet-latest - stream: false - system: |+ - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - - tool_choice: - type: auto - tools: - - description: '' - input_schema: - additionalProperties: false - properties: {} - type: object - name: get_user_country - uri: https://api.anthropic.com/v1/messages?beta=true - response: - headers: - connection: - - keep-alive - content-length: - - '380' - content-type: - - application/json - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - content: - - text: '{"city": "Mexico City", "country": "Mexico"}' - type: text - id: msg_014JeWCouH6DpdqzMTaBdkpJ - model: claude-3-5-sonnet-20241022 - role: assistant - stop_reason: end_turn - stop_sequence: null - type: message - usage: - cache_creation_input_tokens: 0 - cache_read_input_tokens: 0 - input_tokens: 510 - output_tokens: 17 - service_tier: standard - status: - code: 200 - message: OK -version: 1 -... diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml deleted file mode 100644 index 183daa406..000000000 --- a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml +++ /dev/null @@ -1,66 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1268' - content-type: - - application/json - host: - - api.anthropic.com - method: POST - parsed_body: - max_tokens: 1024 - messages: - - content: - - text: What is the largest city in Mexico? - type: text - role: user - model: claude-3-5-sonnet-latest - stream: false - system: |+ - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - - uri: https://api.anthropic.com/v1/messages?beta=true - response: - headers: - connection: - - keep-alive - content-length: - - '434' - content-type: - - application/json - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - content: - - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' - type: text - id: msg_013ttUi3HCcKt7PkJpoWs5FT - model: claude-3-5-sonnet-20241022 - role: assistant - stop_reason: end_turn - stop_sequence: null - type: message - usage: - cache_creation_input_tokens: 0 - cache_read_input_tokens: 0 - input_tokens: 281 - output_tokens: 31 - service_tier: standard - status: - code: 200 - message: OK -version: 1 -... diff --git a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml deleted file mode 100644 index d7f14c9ca..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml +++ /dev/null @@ -1,79 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '305' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - response_mime_type: application/json - response_schema: - properties: - city: - type: string - country: - type: string - required: - - city - - country - title: CityLocation - type: object - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '710' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=819 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.00018302639946341515 - content: - parts: - - text: |- - { - "city": "Mexico City", - "country": "Mexico" - } - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: SEVIaJvJHICK7dcP3OzRiQQ - usageMetadata: - candidatesTokenCount: 20 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 20 - promptTokenCount: 17 - promptTokensDetails: - - modality: TEXT - tokenCount: 17 - totalTokenCount: 37 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml deleted file mode 100644 index 3b306d133..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml +++ /dev/null @@ -1,120 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '791' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the primarily language spoken in Mexico? - role: user - generationConfig: - response_mime_type: application/json - response_schema: - properties: - result: - anyOf: - - description: CityLocation - properties: - data: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - enum: - - CityLocation - type: string - required: - - kind - - data - type: object - - description: CountryLanguage - properties: - data: - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - enum: - - CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '800' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=963 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -3.3667640072172103e-06 - content: - parts: - - text: |- - { - "result": { - "data": { - "country": "Mexico", - "language": "Spanish" - }, - "kind": "CountryLanguage" - } - } - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 2jxIaPucEYCK7dcP3OzRiQQ - usageMetadata: - candidatesTokenCount: 46 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 46 - promptTokenCount: 46 - promptTokensDetails: - - modality: TEXT - tokenCount: 46 - totalTokenCount: 92 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml deleted file mode 100644 index 2268e7f84..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml +++ /dev/null @@ -1,74 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '521' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - response_mime_type: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '880' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=841 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.007913463882037572 - content: - parts: - - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], - "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 2zxIaIiLE4CK7dcP3OzRiQQ - usageMetadata: - candidatesTokenCount: 56 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 56 - promptTokenCount: 80 - promptTokensDetails: - - modality: TEXT - tokenCount: 80 - totalTokenCount: 136 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml deleted file mode 100644 index e96fc20d7..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml +++ /dev/null @@ -1,73 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1287' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - response_mime_type: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: user - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '757' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=823 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0030997690779191477 - content: - parts: - - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: Wz1IaOH5OdGU7dcPjoS34QI - usageMetadata: - candidatesTokenCount: 27 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 27 - promptTokenCount: 253 - promptTokensDetails: - - modality: TEXT - tokenCount: 253 - totalTokenCount: 280 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml deleted file mode 100644 index f10da3ad7..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml +++ /dev/null @@ -1,157 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '615' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '653' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=4501 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - functionCall: - args: {} - name: get_user_country - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: rj9IaPTzNdCBqtsPg-GD6QU - usageMetadata: - candidatesTokenCount: 12 - promptTokenCount: 123 - promptTokensDetails: - - modality: TEXT - tokenCount: 123 - thoughtsTokenCount: 318 - totalTokenCount: 453 - status: - code: 200 - message: OK -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '809' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - - parts: - - functionCall: - args: {} - name: get_user_country - role: model - - parts: - - functionResponse: - name: get_user_country - response: - return_value: Mexico - role: user - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '616' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=1823 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - text: '{"city": "Mexico City", "country": "Mexico"}' - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: sD9IaOCyLPqumtkP6p_T0AE - usageMetadata: - candidatesTokenCount: 13 - promptTokenCount: 154 - promptTokensDetails: - - modality: TEXT - tokenCount: 154 - thoughtsTokenCount: 94 - totalTokenCount: 261 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_json_schema_output.yaml b/tests/models/cassettes/test_google/test_google_json_schema_output.yaml deleted file mode 100644 index 1d9ae0339..000000000 --- a/tests/models/cassettes/test_google/test_google_json_schema_output.yaml +++ /dev/null @@ -1,86 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '453' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - responseMimeType: application/json - responseSchema: - properties: - city: - type: STRING - country: - type: STRING - property_ordering: - - city - - country - required: - - city - - country - title: CityLocation - type: OBJECT - toolConfig: - functionCallingConfig: - allowedFunctionNames: [] - mode: ANY - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '710' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=780 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0002309985226020217 - content: - parts: - - text: |- - { - "city": "Mexico City", - "country": "Mexico" - } - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: Gm9HaNr3KteI_NUPmYvnoA8 - usageMetadata: - candidatesTokenCount: 20 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 20 - promptTokenCount: 19 - promptTokensDetails: - - modality: TEXT - tokenCount: 19 - totalTokenCount: 39 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml deleted file mode 100644 index 74dd03c89..000000000 --- a/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml +++ /dev/null @@ -1,138 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1200' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the primarily language spoken in Mexico? - role: user - generationConfig: - responseMimeType: application/json - responseSchema: - description: The final response which ends this conversation - properties: - result: - any_of: - - description: CityLocation - properties: - data: - properties: - city: - type: STRING - country: - type: STRING - property_ordering: - - city - - country - required: - - city - - country - type: OBJECT - kind: - enum: - - CityLocation - type: STRING - property_ordering: - - kind - - data - required: - - kind - - data - type: OBJECT - - description: CountryLanguage - properties: - data: - properties: - country: - type: STRING - language: - type: STRING - property_ordering: - - country - - language - required: - - country - - language - type: OBJECT - kind: - enum: - - CountryLanguage - type: STRING - property_ordering: - - kind - - data - required: - - kind - - data - type: OBJECT - required: - - result - title: final_result - type: OBJECT - toolConfig: - functionCallingConfig: - allowedFunctionNames: [] - mode: ANY - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '800' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=884 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0002536005138055138 - content: - parts: - - text: |- - { - "result": { - "kind": "CountryLanguage", - "data": { - "country": "Mexico", - "language": "Spanish" - } - } - } - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: W29HaJzGMNGU7dcPjoS34QI - usageMetadata: - candidatesTokenCount: 46 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 46 - promptTokenCount: 64 - promptTokensDetails: - - modality: TEXT - tokenCount: 64 - totalTokenCount: 110 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml deleted file mode 100644 index 3b241acae..000000000 --- a/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml +++ /dev/null @@ -1,78 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '619' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - responseMimeType: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - toolConfig: - functionCallingConfig: - allowedFunctionNames: [] - mode: ANY - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '879' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=829 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.010130892906870161 - content: - parts: - - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], - "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 4HlHaK75MdGU7dcPjoS34QI - usageMetadata: - candidatesTokenCount: 56 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 56 - promptTokenCount: 80 - promptTokensDetails: - - modality: TEXT - tokenCount: 80 - totalTokenCount: 136 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml deleted file mode 100644 index 33383473f..000000000 --- a/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml +++ /dev/null @@ -1,77 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1341' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - responseMimeType: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: user - toolConfig: - functionCallingConfig: - allowedFunctionNames: [] - mode: ANY - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '758' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=734 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0008548707873732956 - content: - parts: - - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 6nlHaO_5GdeI_NUPmYvnoA8 - usageMetadata: - candidatesTokenCount: 27 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 27 - promptTokenCount: 241 - promptTokensDetails: - - modality: TEXT - tokenCount: 241 - totalTokenCount: 268 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml deleted file mode 100644 index 976533c66..000000000 --- a/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml +++ /dev/null @@ -1,164 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '658' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - generationConfig: {} - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '653' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=3776 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - functionCall: - args: {} - name: get_user_country - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: FnpHaOqcKrzQz7IPkuLo8QE - usageMetadata: - candidatesTokenCount: 12 - promptTokenCount: 123 - promptTokensDetails: - - modality: TEXT - tokenCount: 123 - thoughtsTokenCount: 266 - totalTokenCount: 401 - status: - code: 200 - message: OK -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '967' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - - parts: - - functionCall: - args: {} - id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 - name: get_user_country - role: model - - parts: - - functionResponse: - id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 - name: get_user_country - response: - return_value: Mexico - role: user - generationConfig: {} - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '630' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=1888 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - text: |- - ```json - {"city": "Mexico City", "country": "Mexico"} - ``` - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: GHpHaOPkI43Qz7IPxt6T2Ac - usageMetadata: - candidatesTokenCount: 18 - promptTokenCount: 154 - promptTokensDetails: - - modality: TEXT - tokenCount: 154 - thoughtsTokenCount: 94 - totalTokenCount: 266 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml deleted file mode 100644 index ff4477f3d..000000000 --- a/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml +++ /dev/null @@ -1,223 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '522' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - n: 1 - response_format: - json_schema: - name: result - schema: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: false - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1066' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '341' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_PkRGedQNRFUzJp2R7dO7avWR - type: function - created: 1746142582 - id: chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3 - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_f5bdcc3276 - usage: - completion_tokens: 12 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 71 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 83 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '753' - content-type: - - application/json - cookie: - - __cf_bm=dOa3_E1SoWV.vbgZ8L7tx8o9S.XyNTE.YS0K0I3JHq4-1746142583-1.0.1.1-0TuvhdYsoD.J1522DBXH0yrAP_M9MlzvlcpyfwQQNZy.KO5gri6ejQ.gFuwLV5hGhuY0W2uI1dN7ZF1lirVHKeEnEz5s_89aJjrMWjyBd8M; - _cfuvid=xQIJVHkOP28w5fPnAvDHPiCRlU7kkNj6iFV87W4u8Ds-1746142583128-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_PkRGedQNRFUzJp2R7dO7avWR - type: function - - content: Mexico - role: tool - tool_call_id: call_PkRGedQNRFUzJp2R7dO7avWR - model: gpt-4o - n: 1 - response_format: - json_schema: - name: result - schema: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: false - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '852' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '553' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"city":"Mexico City","country":"Mexico"}' - refusal: null - role: assistant - created: 1746142583 - id: chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_f5bdcc3276 - usage: - completion_tokens: 15 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 92 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 107 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml deleted file mode 100644 index d01e28ab0..000000000 --- a/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml +++ /dev/null @@ -1,293 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1120' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - response_format: - json_schema: - description: The final response which ends this conversation - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - required: - - kind - - data - type: object - required: - - result - type: object - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1068' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '868' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_SIttSeiOistt33Htj4oiHOOX - type: function - created: 1749511286 - id: chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 160 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 171 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1351' - content-type: - - application/json - cookie: - - __cf_bm=OFzdr.HrmtC0DNdnfrTQYsK8_PwAVR9GUqjYSCgwtVM-1749511286-1.0.1.1-9_dbth7ET4rzl01UDRTw3fY1nJ20FnMCC0BBmd57gzKF8n5DnNQaI4K1mT.23nn9IUsMyHAZUNn6t1EML3d7GfGJyiLZOxrTWaqacALgzlM; - _cfuvid=f32dQYPsRd6Jc7kg.3hHa1QYAyG8f_aMMXUF.bC6gmY-1749511286914-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_SIttSeiOistt33Htj4oiHOOX - type: function - - content: Mexico - role: tool - tool_call_id: call_SIttSeiOistt33Htj4oiHOOX - model: gpt-4o - response_format: - json_schema: - description: The final response which ends this conversation - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - required: - - kind - - data - type: object - required: - - result - type: object - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '903' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '920' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - refusal: null - role: assistant - created: 1749511287 - id: chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 25 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 181 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 206 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml deleted file mode 100644 index 4eed79085..000000000 --- a/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml +++ /dev/null @@ -1,209 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '690' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1068' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '569' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_s7oT9jaLAsEqTgvxZTmFh0wB - type: function - created: 1749514895 - id: chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_07871e2ad8 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 109 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 120 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '921' - content-type: - - application/json - cookie: - - __cf_bm=jcec.FXQ2vs1UTNFhcDbuMrvzdFu7d7L1To24_vRFiQ-1749514896-1.0.1.1-PEeul2ZYkvLFmEXXk4Xlgvun2HcuGEJ0UUliLVWKx17kMCjZ8WiZbB2Yavq3RRGlxsJZsAWIVMQQ10Vb_2aqGVtQ2aiYTlnDMX3Ktkuciyk; - _cfuvid=zanrNpp5OAiS0wLKfkW9LCs3qTO2FvIaiBZptR_D2P0-1749514896187-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_s7oT9jaLAsEqTgvxZTmFh0wB - type: function - - content: Mexico - role: tool - tool_call_id: call_s7oT9jaLAsEqTgvxZTmFh0wB - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '853' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '718' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"city":"Mexico City","country":"Mexico"}' - refusal: null - role: assistant - created: 1749514896 - id: chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0 - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_07871e2ad8 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 130 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 141 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml deleted file mode 100644 index 3d3ba886a..000000000 --- a/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml +++ /dev/null @@ -1,209 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1412' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1068' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '428' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_wJD14IyJ4KKVtjCrGyNCHO09 - type: function - created: 1749514898 - id: chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 273 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 284 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1643' - content-type: - - application/json - cookie: - - __cf_bm=gqjIEMZSez95CPkkPVuU_AoDutHrobFMbFPjq43G66M-1749514899-1.0.1.1-5TGB9WajW5pzCRtVtWeQfiwyQUZs1JwWy9qC8VGlgq7s5pQWKerukQtYB7GqNDrdb.1pbtFyt2HZ9xV3YiSbK4H1bZS_hS1CCeoup_3IQW0; - _cfuvid=ZN6eoNau4b.bJ8kvRn2z9R0HgTUd9nOsupKUtLXQowU-1749514899280-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_wJD14IyJ4KKVtjCrGyNCHO09 - type: function - - content: Mexico - role: tool - tool_call_id: call_wJD14IyJ4KKVtjCrGyNCHO09 - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '903' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '763' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - refusal: null - role: assistant - created: 1749514899 - id: chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 21 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 294 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 315 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml b/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml deleted file mode 100644 index 9fd1b6989..000000000 --- a/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml +++ /dev/null @@ -1,288 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '533' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1808' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '636' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516047 - error: null - id: resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_tTAThu8l2S9hNky2krdwijGP - id: fc_68477f0fa7c081a19a525f7c6f180f310b8591d9001d2329 - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 66 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 78 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '769' - content-type: - - application/json - cookie: - - __cf_bm=My3TWVEPFsaYcjJ.iWxTB6P67jFSuxSF.n13qHpH9BA-1749516047-1.0.1.1-2bg2ltV1yu2uhfqewI9eEG1ulzfU_gq8pLx9YwHte33BTk2PgxBwaRdyegdEs_dVkAbaCoAPsQRIQmW21QPf_U2Fd1vdibnoExA_.rvTYv8; - _cfuvid=_7XoQBGwU.UsQgiPHVWMTXLLbADtbSwhrO9PY7I_3Dw-1749516047790-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_tTAThu8l2S9hNky2krdwijGP - name: get_user_country - type: function_call - - call_id: call_tTAThu8l2S9hNky2krdwijGP - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1902' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '883' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516047 - error: null - id: resp_68477f0fde708192989000a62809c6e5020197534e39cc1f - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"city":"Mexico City","country":"Mexico"}' - type: output_text - id: msg_68477f10846c81929f1e833b0785e6f3020197534e39cc1f - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 89 - input_tokens_details: - cached_tokens: 0 - output_tokens: 16 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 105 - user: null - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml deleted file mode 100644 index 9c411f3c7..000000000 --- a/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml +++ /dev/null @@ -1,444 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1143' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '3657' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '562' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516048 - error: null - id: resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP - id: fc_68477f1168a081a3981e847cd94275080dd57d732903c563 - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 153 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 165 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1379' - content-type: - - application/json - cookie: - - __cf_bm=3Nl1ERbtfVAI.dGjzCYYN1u71YD5eEoLU0iCrvPPPL0-1749516049-1.0.1.1-LnI7tJwKr.C_wA15Shsl8pcGd32zrRqqv_9u4S84nXtNCopx1iBIKYDsyMg3u1Z3lJ_1Cd1YVM8uKAMjiKmgoqS8GFQ3Z_vV_Mahvqbi4KA; - _cfuvid=oc_k9l86fnMo2ml.0aop6a3eVJEvjxB0lnxWK0_kJq8-1749516049524-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP - name: get_user_country - type: function_call - - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '3800' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '1042' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516049 - error: null - id: resp_68477f119830819da162aa6e10552035061ad97e2eef7871 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - type: output_text - id: msg_68477f1235b8819d898adc64709c7ebf061ad97e2eef7871 - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 176 - input_tokens_details: - cached_tokens: 0 - output_tokens: 26 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 202 - user: null - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml deleted file mode 100644 index 35783c516..000000000 --- a/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml +++ /dev/null @@ -1,248 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '689' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1408' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '8314' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561106 - error: null - id: resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom - id: fc_68482f1b0ff081a1b37b9170ee740d1e02f8ef7f2fb42b50 - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 107 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 119 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '925' - content-type: - - application/json - cookie: - - __cf_bm=8a8rNQQYozQt3YjcA61k6KGe.AlrMMrtcIvKv.D1s1E-1749561115-1.0.1.1-OFcqg8xD2_HdbeO74bU2.mLTqDuiK.ploHeu3_ITPvDlGwrVkwk8erMkHagxk4UDxACCCAygnUs1HL.F4AGjQCaZm1m2eYiMVbLqp0iQh7g; - _cfuvid=wKTRRc2dbdYNYnYwA2vRxVjUvqqkQovvKDwULW0Xwns-1749561115173-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom - name: get_user_country - type: function_call - - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1501' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '1098' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561115 - error: null - id: resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"city":"Mexico City","country":"Mexico"}' - type: output_text - id: msg_68482f1c159081918a2405f458009a6a044fdb7d019d4115 - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 130 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 142 - user: null - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml deleted file mode 100644 index 1a3b4dc00..000000000 --- a/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml +++ /dev/null @@ -1,248 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1455' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1408' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '11445' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561117 - error: null - id: resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI - id: fc_68482f2889d481a199caa61de7ccb62c08e79646fe74d5ee - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 283 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 295 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1691' - content-type: - - application/json - cookie: - - __cf_bm=l95LdgPzGHw0UAhBwse9ADphgmMDWrhYqgiO4gdmSy4-1749561128-1.0.1.1-9zPIs3d5_ipszLpQ7yBaCZEStp8qoRIGFshR93V6n7Z_7AznH0MfuczwuoiaW8e6cEVeVHLhskjXScolO9gP5TmpsaFo37GRuHsHZTRgEeI; - _cfuvid=5L5qtbtbFCFzMmoVufSY.ksn06ay8AFs.UXFEv07pkY-1749561128680-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI - name: get_user_country - type: function_call - - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1551' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '2545' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561128 - error: null - id: resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - type: output_text - id: msg_68482f296bfc81a18665547d4008ab2c06b4ab2d00d03024 - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 306 - input_tokens_details: - cached_tokens: 0 - output_tokens: 22 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 328 - user: null - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 54fad1ec6..73207dfad 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -27,7 +27,7 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.result import PromptedJsonOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from ..conftest import IsDatetime, IsNow, IsStr, TestEnv, raise_if_exception, try_import @@ -1301,14 +1301,14 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_anthropic_prompted_json_output(allow_model_requests: None, anthropic_api_key: str): +async def test_anthropic_structured_text_output(allow_model_requests: None, anthropic_api_key: str): m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -1396,7 +1396,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_anthropic_prompted_json_output_multiple(allow_model_requests: None, anthropic_api_key: str): +async def test_anthropic_structured_text_output_multiple(allow_model_requests: None, anthropic_api_key: str): m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) class CityLocation(BaseModel): @@ -1407,7 +1407,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a3810fe0c..a7212f730 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -53,7 +53,7 @@ _GeminiUsageMetaData, ) from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage from pydantic_ai.tools import ToolDefinition from ..conftest import ClientWithHandler, IsDatetime, IsNow, IsStr, TestEnv @@ -1582,14 +1582,14 @@ def upcase(text: str) -> str: @pytest.mark.vcr() -async def test_gemini_json_schema_output_with_tools(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_structured_text_output_with_tools(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -1600,7 +1600,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_gemini_json_schema_output(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_structured_text_output(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): @@ -1609,7 +1609,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1652,7 +1652,7 @@ class CityLocation(BaseModel): @pytest.mark.vcr() -async def test_gemini_json_schema_output_multiple(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_structured_text_output_multiple(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): @@ -1663,7 +1663,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the primarily language spoken in Mexico?') assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) @@ -1711,14 +1711,14 @@ class CountryLanguage(BaseModel): @pytest.mark.vcr() -async def test_gemini_prompted_json_output(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_structured_text_output_with_instructions(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1763,14 +1763,16 @@ class CityLocation(BaseModel): @pytest.mark.vcr() -async def test_gemini_prompted_json_output_with_tools(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_structured_text_output_with_instructions_with_tools( + allow_model_requests: None, gemini_api_key: str +): m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) @agent.tool_plain async def get_user_country() -> str: @@ -1848,7 +1850,9 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_gemini_prompted_json_output_multiple(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_structured_text_output_with_instructions_multiple( + allow_model_requests: None, gemini_api_key: str +): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): @@ -1859,7 +1863,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index e363c330c..503f7c678 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -35,7 +35,7 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -881,14 +881,14 @@ async def get_user_country() -> str: ) -async def test_google_json_schema_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_structured_text_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -898,7 +898,7 @@ async def get_user_country() -> str: await agent.run('What is the largest city in the user country?') -async def test_google_json_schema_output(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_structured_text_output(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): @@ -907,7 +907,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -948,7 +948,7 @@ class CityLocation(BaseModel): ) -async def test_google_json_schema_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_structured_text_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): @@ -959,7 +959,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the primarily language spoken in Mexico?') assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) @@ -1005,14 +1005,16 @@ class CountryLanguage(BaseModel): ) -async def test_google_prompted_json_output(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_structured_text_output_with_instructions( + allow_model_requests: None, google_provider: GoogleProvider +): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1055,14 +1057,16 @@ class CityLocation(BaseModel): ) -async def test_google_prompted_json_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_structured_text_output_with_instructions_with_tools( + allow_model_requests: None, google_provider: GoogleProvider +): m = GoogleModel('gemini-2.5-pro-preview-05-06', provider=google_provider) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) @agent.tool_plain async def get_user_country() -> str: @@ -1145,7 +1149,9 @@ async def get_user_country() -> str: ) -async def test_google_prompted_json_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_structured_text_output_with_instructions_multiple( + allow_model_requests: None, google_provider: GoogleProvider +): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): @@ -1156,7 +1162,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 99238690c..d14ddb470 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -34,7 +34,7 @@ from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -548,7 +548,7 @@ async def test_stream_structured_json_schema_output(allow_model_requests: None): ] mock_client = MockOpenAI.create_mock_stream(stream) m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) - agent = Agent(m, output_type=JsonSchemaOutput(MyTypedDict)) + agent = Agent(m, output_type=StructuredTextOutput(MyTypedDict)) async with agent.run_stream('') as result: assert not result.is_complete @@ -1962,7 +1962,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_json_schema_output(allow_model_requests: None, openai_api_key: str): +async def test_openai_structured_text_output(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -1971,7 +1971,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -2045,7 +2045,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_json_schema_output_multiple(allow_model_requests: None, openai_api_key: str): +async def test_openai_structured_text_output_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -2056,7 +2056,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: @@ -2134,14 +2134,14 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_prompted_json_output(allow_model_requests: None, openai_api_key: str): +async def test_openai_structured_text_output_with_instructions(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) @agent.tool_plain async def get_user_country() -> str: @@ -2229,7 +2229,9 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_prompted_json_output_multiple(allow_model_requests: None, openai_api_key: str): +async def test_openai_structured_text_output_with_instructions_multiple( + allow_model_requests: None, openai_api_key: str +): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -2240,7 +2242,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) @agent.tool_plain async def get_user_country() -> str: diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index de4c817e2..68e405358 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -21,7 +21,7 @@ UserPromptPart, ) from pydantic_ai.profiles.openai import openai_model_profile -from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput +from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import Usage @@ -565,6 +565,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7', ), ModelRequest( parts=[ @@ -593,6 +594,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560', ), ModelRequest( parts=[ @@ -647,6 +649,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a', ), ModelRequest( parts=[ @@ -668,13 +671,14 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d', ), ] ) @pytest.mark.vcr() -async def test_json_schema_output(allow_model_requests: None, openai_api_key: str): +async def test_structured_text_output(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -683,7 +687,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -715,6 +719,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329', ), ModelRequest( parts=[ @@ -736,13 +741,14 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f0fde708192989000a62809c6e5020197534e39cc1f', ), ] ) @pytest.mark.vcr() -async def test_json_schema_output_multiple(allow_model_requests: None, openai_api_key: str): +async def test_structured_text_output_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -753,7 +759,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: @@ -785,6 +791,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563', ), ModelRequest( parts=[ @@ -810,20 +817,21 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68477f119830819da162aa6e10552035061ad97e2eef7871', ), ] ) @pytest.mark.vcr() -async def test_prompted_json_output(allow_model_requests: None, openai_api_key: str): +async def test_structured_text_output_with_instructions(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) @agent.tool_plain async def get_user_country() -> str: @@ -862,6 +870,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50', ), ModelRequest( parts=[ @@ -890,13 +899,14 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115', ), ] ) @pytest.mark.vcr() -async def test_prompted_json_output_multiple(allow_model_requests: None, openai_api_key: str): +async def test_structured_text_output_with_instructions_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -907,7 +917,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) @agent.tool_plain async def get_user_country() -> str: @@ -946,6 +956,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee', ), ModelRequest( parts=[ @@ -978,6 +989,7 @@ async def get_user_country() -> str: ), model_name='gpt-4o-2024-08-06', timestamp=IsDatetime(), + vendor_id='resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024', ), ] ) diff --git a/tests/test_agent.py b/tests/test_agent.py index 0bbe7b1c7..71c34c0ea 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -14,9 +14,8 @@ from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai._output import ( - JsonSchemaOutput, OutputSpec, - PromptedJsonOutput, + StructuredTextOutput, TextOutput, TextOutputSchema, ToolOutput, @@ -1263,7 +1262,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) -def test_output_type_prompted_json(): +def test_output_type_structured_text(): def return_city_location(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: text = CityLocation(city='Mexico City', country='Mexico').model_dump_json() return ModelResponse(parts=[TextPart(content=text)]) @@ -1278,8 +1277,8 @@ class CityLocation(BaseModel): agent = Agent( m, - output_type=PromptedJsonOutput( - CityLocation, name='City & Country', description='Description from PromptedJsonOutput' + output_type=StructuredTextOutput( + CityLocation, name='City & Country', description='Description from StructuredTextOutput' ), ) @@ -1297,7 +1296,7 @@ class CityLocation(BaseModel): instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from PromptedJsonOutput. Description from docstring."} +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from StructuredTextOutput. Description from docstring."} Don't include any text or Markdown fencing before or after.\ """, @@ -1312,7 +1311,7 @@ class CityLocation(BaseModel): ) -def test_output_type_prompted_json_with_defs(): +def test_output_type_structured_text_with_defs(): class Foo(BaseModel): """Foo description""" @@ -1348,7 +1347,7 @@ def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent( m, - output_type=PromptedJsonOutput( + output_type=StructuredTextOutput( [FooBar, FooBaz], name='FooBar or FooBaz', description='FooBar or FooBaz description' ), ) @@ -1394,7 +1393,7 @@ def return_city_location(messages: list[ModelMessage], _info: AgentInfo) -> Mode text = '{"city": "Mexico City", "country": "Mexico"}' return ModelResponse(parts=[TextPart(content=text)]) - m = FunctionModel(return_city_location, profile=ModelProfile(output_modes={'tool', 'json_schema'})) + m = FunctionModel(return_city_location, profile=ModelProfile(supports_json_schema_response_format=True)) class CityLocation(BaseModel): city: str @@ -1402,7 +1401,7 @@ class CityLocation(BaseModel): agent = Agent( m, - output_type=JsonSchemaOutput(CityLocation), + output_type=StructuredTextOutput(CityLocation), ) result = agent.run_sync('What is the capital of Mexico?') @@ -1449,7 +1448,7 @@ class CityLocation(BaseModel): ) -def test_output_type_prompted_json_function_with_retry(): +def test_output_type_structured_text_function_with_retry(): class Weather(BaseModel): temperature: float description: str @@ -1469,7 +1468,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content=args_json)]) - agent = Agent(FunctionModel(call_tool), output_type=PromptedJsonOutput(get_weather)) + agent = Agent(FunctionModel(call_tool), output_type=StructuredTextOutput(get_weather, instructions=True)) result = agent.run_sync('New York City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert result.all_messages() == snapshot( @@ -3049,7 +3048,7 @@ def test_unsupported_output_mode(): class Foo(BaseModel): bar: str - agent = Agent('test', output_type=JsonSchemaOutput(Foo)) + agent = Agent('test', output_type=StructuredTextOutput(Foo, instructions=False)) - with pytest.raises(UserError, match="Output mode 'json_schema' is not among supported modes: 'tool'"): + with pytest.raises(UserError, match='Structured output without using instructions is not supported by the model.'): agent.run_sync('Hello') diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 75935d33f..b1fc3c50c 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -13,7 +13,7 @@ from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import PromptedJsonOutput, TextOutput +from pydantic_ai._output import StructuredTextOutput, TextOutput from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( ModelMessage, @@ -940,14 +940,14 @@ def output_validator(data: OutputType | NotOutputType) -> OutputType | NotOutput assert outputs == [OutputType(value='a (validated)'), OutputType(value='a (validated)')] -async def test_stream_output_type_prompted_json(): +async def test_stream_output_type_structured_text(): class CityLocation(BaseModel): city: str country: str | None = None m = TestModel(custom_output_text='{"city": "Mexico City", "country": "Mexico"}') - agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) async with agent.run_stream('') as result: assert not result.is_complete From 5112455f27b97b86cd6435db1d89dd1c7e83725a Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 19:33:49 +0000 Subject: [PATCH 26/90] Add missing cassettes --- ...test_anthropic_structured_text_output.yaml | 161 +++++++ ...ropic_structured_text_output_multiple.yaml | 66 +++ .../test_gemini_structured_text_output.yaml | 79 ++++ ...emini_structured_text_output_multiple.yaml | 120 +++++ ...uctured_text_output_with_instructions.yaml | 74 +++ ...ext_output_with_instructions_multiple.yaml | 73 +++ ...t_output_with_instructions_with_tools.yaml | 157 +++++++ .../test_google_structured_text_output.yaml | 86 ++++ ...oogle_structured_text_output_multiple.yaml | 138 ++++++ ...uctured_text_output_with_instructions.yaml | 78 +++ ...ext_output_with_instructions_multiple.yaml | 77 +++ ...t_output_with_instructions_with_tools.yaml | 164 +++++++ .../test_openai_structured_text_output.yaml | 223 +++++++++ ...penai_structured_text_output_multiple.yaml | 293 ++++++++++++ ...uctured_text_output_with_instructions.yaml | 209 +++++++++ ...ext_output_with_instructions_multiple.yaml | 209 +++++++++ .../test_structured_text_output.yaml | 288 ++++++++++++ .../test_structured_text_output_multiple.yaml | 444 ++++++++++++++++++ ...uctured_text_output_with_instructions.yaml | 248 ++++++++++ ...ext_output_with_instructions_multiple.yaml | 248 ++++++++++ 20 files changed, 3435 insertions(+) create mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output.yaml create mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output_multiple.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_structured_text_output.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_structured_text_output_multiple.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_multiple.yaml create mode 100644 tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_with_tools.yaml create mode 100644 tests/models/cassettes/test_google/test_google_structured_text_output.yaml create mode 100644 tests/models/cassettes/test_google/test_google_structured_text_output_multiple.yaml create mode 100644 tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions.yaml create mode 100644 tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_multiple.yaml create mode 100644 tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_with_tools.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_structured_text_output.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_structured_text_output_multiple.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions.yaml create mode 100644 tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions_multiple.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_structured_text_output.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_structured_text_output_multiple.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions.yaml create mode 100644 tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions_multiple.yaml diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output.yaml new file mode 100644 index 000000000..e88afebdf --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output.yaml @@ -0,0 +1,161 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '740' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '397' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - id: toolu_017UryVwtsKsjonhFV3cgV3X + input: {} + name: get_user_country + type: tool_use + id: msg_014CpBKzioMqUyLWrMihpvsz + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 459 + output_tokens: 38 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1002' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + - content: + - id: toolu_017UryVwtsKsjonhFV3cgV3X + input: {} + name: get_user_country + type: tool_use + role: assistant + - content: + - content: Mexico + is_error: false + tool_use_id: toolu_017UryVwtsKsjonhFV3cgV3X + type: tool_result + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '380' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: '{"city": "Mexico City", "country": "Mexico"}' + type: text + id: msg_014JeWCouH6DpdqzMTaBdkpJ + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 510 + output_tokens: 17 + service_tier: standard + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output_multiple.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output_multiple.yaml new file mode 100644 index 000000000..183daa406 --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output_multiple.yaml @@ -0,0 +1,66 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1268' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in Mexico? + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '434' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + type: text + id: msg_013ttUi3HCcKt7PkJpoWs5FT + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 281 + output_tokens: 31 + service_tier: standard + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output.yaml new file mode 100644 index 000000000..d7f14c9ca --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '305' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + response_schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + title: CityLocation + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '710' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=819 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.00018302639946341515 + content: + parts: + - text: |- + { + "city": "Mexico City", + "country": "Mexico" + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: SEVIaJvJHICK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 20 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 20 + promptTokenCount: 17 + promptTokensDetails: + - modality: TEXT + tokenCount: 17 + totalTokenCount: 37 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_multiple.yaml new file mode 100644 index 000000000..3b306d133 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_multiple.yaml @@ -0,0 +1,120 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '791' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the primarily language spoken in Mexico? + role: user + generationConfig: + response_mime_type: application/json + response_schema: + properties: + result: + anyOf: + - description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + enum: + - CityLocation + type: string + required: + - kind + - data + type: object + - description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + enum: + - CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '800' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=963 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -3.3667640072172103e-06 + content: + parts: + - text: |- + { + "result": { + "data": { + "country": "Mexico", + "language": "Spanish" + }, + "kind": "CountryLanguage" + } + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 2jxIaPucEYCK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 46 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 46 + promptTokenCount: 46 + promptTokensDetails: + - modality: TEXT + tokenCount: 46 + totalTokenCount: 92 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions.yaml new file mode 100644 index 000000000..2268e7f84 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions.yaml @@ -0,0 +1,74 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '521' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '880' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=841 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.007913463882037572 + content: + parts: + - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], + "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 2zxIaIiLE4CK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 56 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 56 + promptTokenCount: 80 + promptTokensDetails: + - modality: TEXT + tokenCount: 80 + totalTokenCount: 136 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_multiple.yaml new file mode 100644 index 000000000..e96fc20d7 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_multiple.yaml @@ -0,0 +1,73 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1287' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '757' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=823 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0030997690779191477 + content: + parts: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: Wz1IaOH5OdGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 27 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 27 + promptTokenCount: 253 + promptTokensDetails: + - modality: TEXT + tokenCount: 253 + totalTokenCount: 280 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_with_tools.yaml b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_with_tools.yaml new file mode 100644 index 000000000..f10da3ad7 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_with_tools.yaml @@ -0,0 +1,157 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '615' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '653' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=4501 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: rj9IaPTzNdCBqtsPg-GD6QU + usageMetadata: + candidatesTokenCount: 12 + promptTokenCount: 123 + promptTokensDetails: + - modality: TEXT + tokenCount: 123 + thoughtsTokenCount: 318 + totalTokenCount: 453 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '809' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - functionCall: + args: {} + name: get_user_country + role: model + - parts: + - functionResponse: + name: get_user_country + response: + return_value: Mexico + role: user + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '616' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1823 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: sD9IaOCyLPqumtkP6p_T0AE + usageMetadata: + candidatesTokenCount: 13 + promptTokenCount: 154 + promptTokensDetails: + - modality: TEXT + tokenCount: 154 + thoughtsTokenCount: 94 + totalTokenCount: 261 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output.yaml b/tests/models/cassettes/test_google/test_google_structured_text_output.yaml new file mode 100644 index 000000000..1d9ae0339 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_structured_text_output.yaml @@ -0,0 +1,86 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '453' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + responseSchema: + properties: + city: + type: STRING + country: + type: STRING + property_ordering: + - city + - country + required: + - city + - country + title: CityLocation + type: OBJECT + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '710' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=780 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0002309985226020217 + content: + parts: + - text: |- + { + "city": "Mexico City", + "country": "Mexico" + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: Gm9HaNr3KteI_NUPmYvnoA8 + usageMetadata: + candidatesTokenCount: 20 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 20 + promptTokenCount: 19 + promptTokensDetails: + - modality: TEXT + tokenCount: 19 + totalTokenCount: 39 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_structured_text_output_multiple.yaml new file mode 100644 index 000000000..74dd03c89 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_structured_text_output_multiple.yaml @@ -0,0 +1,138 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1200' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the primarily language spoken in Mexico? + role: user + generationConfig: + responseMimeType: application/json + responseSchema: + description: The final response which ends this conversation + properties: + result: + any_of: + - description: CityLocation + properties: + data: + properties: + city: + type: STRING + country: + type: STRING + property_ordering: + - city + - country + required: + - city + - country + type: OBJECT + kind: + enum: + - CityLocation + type: STRING + property_ordering: + - kind + - data + required: + - kind + - data + type: OBJECT + - description: CountryLanguage + properties: + data: + properties: + country: + type: STRING + language: + type: STRING + property_ordering: + - country + - language + required: + - country + - language + type: OBJECT + kind: + enum: + - CountryLanguage + type: STRING + property_ordering: + - kind + - data + required: + - kind + - data + type: OBJECT + required: + - result + title: final_result + type: OBJECT + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '800' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=884 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0002536005138055138 + content: + parts: + - text: |- + { + "result": { + "kind": "CountryLanguage", + "data": { + "country": "Mexico", + "language": "Spanish" + } + } + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: W29HaJzGMNGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 46 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 46 + promptTokenCount: 64 + promptTokensDetails: + - modality: TEXT + tokenCount: 64 + totalTokenCount: 110 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions.yaml new file mode 100644 index 000000000..3b241acae --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions.yaml @@ -0,0 +1,78 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '619' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '879' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=829 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.010130892906870161 + content: + parts: + - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], + "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 4HlHaK75MdGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 56 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 56 + promptTokenCount: 80 + promptTokensDetails: + - modality: TEXT + tokenCount: 80 + totalTokenCount: 136 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_multiple.yaml new file mode 100644 index 000000000..33383473f --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_multiple.yaml @@ -0,0 +1,77 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1341' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: user + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '758' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=734 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0008548707873732956 + content: + parts: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 6nlHaO_5GdeI_NUPmYvnoA8 + usageMetadata: + candidatesTokenCount: 27 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 27 + promptTokenCount: 241 + promptTokensDetails: + - modality: TEXT + tokenCount: 241 + totalTokenCount: 268 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_with_tools.yaml b/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_with_tools.yaml new file mode 100644 index 000000000..976533c66 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_with_tools.yaml @@ -0,0 +1,164 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '658' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + generationConfig: {} + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '653' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=3776 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: FnpHaOqcKrzQz7IPkuLo8QE + usageMetadata: + candidatesTokenCount: 12 + promptTokenCount: 123 + promptTokensDetails: + - modality: TEXT + tokenCount: 123 + thoughtsTokenCount: 266 + totalTokenCount: 401 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '967' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - functionCall: + args: {} + id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 + name: get_user_country + role: model + - parts: + - functionResponse: + id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 + name: get_user_country + response: + return_value: Mexico + role: user + generationConfig: {} + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '630' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1888 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: |- + ```json + {"city": "Mexico City", "country": "Mexico"} + ``` + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: GHpHaOPkI43Qz7IPxt6T2Ac + usageMetadata: + candidatesTokenCount: 18 + promptTokenCount: 154 + promptTokensDetails: + - modality: TEXT + tokenCount: 154 + thoughtsTokenCount: 94 + totalTokenCount: 266 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output.yaml b/tests/models/cassettes/test_openai/test_openai_structured_text_output.yaml new file mode 100644 index 000000000..ff4477f3d --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_structured_text_output.yaml @@ -0,0 +1,223 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '522' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '341' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + created: 1746142582 + id: chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 71 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 83 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '753' + content-type: + - application/json + cookie: + - __cf_bm=dOa3_E1SoWV.vbgZ8L7tx8o9S.XyNTE.YS0K0I3JHq4-1746142583-1.0.1.1-0TuvhdYsoD.J1522DBXH0yrAP_M9MlzvlcpyfwQQNZy.KO5gri6ejQ.gFuwLV5hGhuY0W2uI1dN7ZF1lirVHKeEnEz5s_89aJjrMWjyBd8M; + _cfuvid=xQIJVHkOP28w5fPnAvDHPiCRlU7kkNj6iFV87W4u8Ds-1746142583128-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + - content: Mexico + role: tool + tool_call_id: call_PkRGedQNRFUzJp2R7dO7avWR + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '852' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '553' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1746142583 + id: chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 15 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 92 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 107 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_structured_text_output_multiple.yaml new file mode 100644 index 000000000..d01e28ab0 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_structured_text_output_multiple.yaml @@ -0,0 +1,293 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1120' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + response_format: + json_schema: + description: The final response which ends this conversation + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + required: + - kind + - data + type: object + required: + - result + type: object + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '868' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_SIttSeiOistt33Htj4oiHOOX + type: function + created: 1749511286 + id: chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 160 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 171 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1351' + content-type: + - application/json + cookie: + - __cf_bm=OFzdr.HrmtC0DNdnfrTQYsK8_PwAVR9GUqjYSCgwtVM-1749511286-1.0.1.1-9_dbth7ET4rzl01UDRTw3fY1nJ20FnMCC0BBmd57gzKF8n5DnNQaI4K1mT.23nn9IUsMyHAZUNn6t1EML3d7GfGJyiLZOxrTWaqacALgzlM; + _cfuvid=f32dQYPsRd6Jc7kg.3hHa1QYAyG8f_aMMXUF.bC6gmY-1749511286914-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_SIttSeiOistt33Htj4oiHOOX + type: function + - content: Mexico + role: tool + tool_call_id: call_SIttSeiOistt33Htj4oiHOOX + model: gpt-4o + response_format: + json_schema: + description: The final response which ends this conversation + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + required: + - kind + - data + type: object + required: + - result + type: object + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '903' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '920' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + refusal: null + role: assistant + created: 1749511287 + id: chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 25 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 181 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 206 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions.yaml new file mode 100644 index 000000000..4eed79085 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions.yaml @@ -0,0 +1,209 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '690' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '569' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_s7oT9jaLAsEqTgvxZTmFh0wB + type: function + created: 1749514895 + id: chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_07871e2ad8 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 109 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 120 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '921' + content-type: + - application/json + cookie: + - __cf_bm=jcec.FXQ2vs1UTNFhcDbuMrvzdFu7d7L1To24_vRFiQ-1749514896-1.0.1.1-PEeul2ZYkvLFmEXXk4Xlgvun2HcuGEJ0UUliLVWKx17kMCjZ8WiZbB2Yavq3RRGlxsJZsAWIVMQQ10Vb_2aqGVtQ2aiYTlnDMX3Ktkuciyk; + _cfuvid=zanrNpp5OAiS0wLKfkW9LCs3qTO2FvIaiBZptR_D2P0-1749514896187-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_s7oT9jaLAsEqTgvxZTmFh0wB + type: function + - content: Mexico + role: tool + tool_call_id: call_s7oT9jaLAsEqTgvxZTmFh0wB + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '853' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '718' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1749514896 + id: chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_07871e2ad8 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 130 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 141 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions_multiple.yaml new file mode 100644 index 000000000..3d3ba886a --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions_multiple.yaml @@ -0,0 +1,209 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1412' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '428' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_wJD14IyJ4KKVtjCrGyNCHO09 + type: function + created: 1749514898 + id: chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 273 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 284 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1643' + content-type: + - application/json + cookie: + - __cf_bm=gqjIEMZSez95CPkkPVuU_AoDutHrobFMbFPjq43G66M-1749514899-1.0.1.1-5TGB9WajW5pzCRtVtWeQfiwyQUZs1JwWy9qC8VGlgq7s5pQWKerukQtYB7GqNDrdb.1pbtFyt2HZ9xV3YiSbK4H1bZS_hS1CCeoup_3IQW0; + _cfuvid=ZN6eoNau4b.bJ8kvRn2z9R0HgTUd9nOsupKUtLXQowU-1749514899280-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_wJD14IyJ4KKVtjCrGyNCHO09 + type: function + - content: Mexico + role: tool + tool_call_id: call_wJD14IyJ4KKVtjCrGyNCHO09 + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '903' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '763' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + refusal: null + role: assistant + created: 1749514899 + id: chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 21 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 294 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 315 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output.yaml b/tests/models/cassettes/test_openai_responses/test_structured_text_output.yaml new file mode 100644 index 000000000..9fd1b6989 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_structured_text_output.yaml @@ -0,0 +1,288 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '533' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1808' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '636' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516047 + error: null + id: resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_tTAThu8l2S9hNky2krdwijGP + id: fc_68477f0fa7c081a19a525f7c6f180f310b8591d9001d2329 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 66 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 78 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '769' + content-type: + - application/json + cookie: + - __cf_bm=My3TWVEPFsaYcjJ.iWxTB6P67jFSuxSF.n13qHpH9BA-1749516047-1.0.1.1-2bg2ltV1yu2uhfqewI9eEG1ulzfU_gq8pLx9YwHte33BTk2PgxBwaRdyegdEs_dVkAbaCoAPsQRIQmW21QPf_U2Fd1vdibnoExA_.rvTYv8; + _cfuvid=_7XoQBGwU.UsQgiPHVWMTXLLbADtbSwhrO9PY7I_3Dw-1749516047790-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_tTAThu8l2S9hNky2krdwijGP + name: get_user_country + type: function_call + - call_id: call_tTAThu8l2S9hNky2krdwijGP + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1902' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '883' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516047 + error: null + id: resp_68477f0fde708192989000a62809c6e5020197534e39cc1f + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"city":"Mexico City","country":"Mexico"}' + type: output_text + id: msg_68477f10846c81929f1e833b0785e6f3020197534e39cc1f + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 89 + input_tokens_details: + cached_tokens: 0 + output_tokens: 16 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 105 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_structured_text_output_multiple.yaml new file mode 100644 index 000000000..9c411f3c7 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_structured_text_output_multiple.yaml @@ -0,0 +1,444 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1143' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '3657' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '562' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516048 + error: null + id: resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + id: fc_68477f1168a081a3981e847cd94275080dd57d732903c563 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 153 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 165 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1379' + content-type: + - application/json + cookie: + - __cf_bm=3Nl1ERbtfVAI.dGjzCYYN1u71YD5eEoLU0iCrvPPPL0-1749516049-1.0.1.1-LnI7tJwKr.C_wA15Shsl8pcGd32zrRqqv_9u4S84nXtNCopx1iBIKYDsyMg3u1Z3lJ_1Cd1YVM8uKAMjiKmgoqS8GFQ3Z_vV_Mahvqbi4KA; + _cfuvid=oc_k9l86fnMo2ml.0aop6a3eVJEvjxB0lnxWK0_kJq8-1749516049524-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + name: get_user_country + type: function_call + - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '3800' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1042' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516049 + error: null + id: resp_68477f119830819da162aa6e10552035061ad97e2eef7871 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + type: output_text + id: msg_68477f1235b8819d898adc64709c7ebf061ad97e2eef7871 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 176 + input_tokens_details: + cached_tokens: 0 + output_tokens: 26 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 202 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions.yaml new file mode 100644 index 000000000..35783c516 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions.yaml @@ -0,0 +1,248 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '689' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1408' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '8314' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561106 + error: null + id: resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + id: fc_68482f1b0ff081a1b37b9170ee740d1e02f8ef7f2fb42b50 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 107 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 119 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '925' + content-type: + - application/json + cookie: + - __cf_bm=8a8rNQQYozQt3YjcA61k6KGe.AlrMMrtcIvKv.D1s1E-1749561115-1.0.1.1-OFcqg8xD2_HdbeO74bU2.mLTqDuiK.ploHeu3_ITPvDlGwrVkwk8erMkHagxk4UDxACCCAygnUs1HL.F4AGjQCaZm1m2eYiMVbLqp0iQh7g; + _cfuvid=wKTRRc2dbdYNYnYwA2vRxVjUvqqkQovvKDwULW0Xwns-1749561115173-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + name: get_user_country + type: function_call + - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1501' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1098' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561115 + error: null + id: resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"city":"Mexico City","country":"Mexico"}' + type: output_text + id: msg_68482f1c159081918a2405f458009a6a044fdb7d019d4115 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 130 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 142 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions_multiple.yaml new file mode 100644 index 000000000..1a3b4dc00 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions_multiple.yaml @@ -0,0 +1,248 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1455' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1408' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '11445' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561117 + error: null + id: resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + id: fc_68482f2889d481a199caa61de7ccb62c08e79646fe74d5ee + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 283 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 295 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1691' + content-type: + - application/json + cookie: + - __cf_bm=l95LdgPzGHw0UAhBwse9ADphgmMDWrhYqgiO4gdmSy4-1749561128-1.0.1.1-9zPIs3d5_ipszLpQ7yBaCZEStp8qoRIGFshR93V6n7Z_7AznH0MfuczwuoiaW8e6cEVeVHLhskjXScolO9gP5TmpsaFo37GRuHsHZTRgEeI; + _cfuvid=5L5qtbtbFCFzMmoVufSY.ksn06ay8AFs.UXFEv07pkY-1749561128680-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + name: get_user_country + type: function_call + - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1551' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '2545' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561128 + error: null + id: resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + type: output_text + id: msg_68482f296bfc81a18665547d4008ab2c06b4ab2d00d03024 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 306 + input_tokens_details: + cached_tokens: 0 + output_tokens: 22 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 328 + user: null + status: + code: 200 + message: OK +version: 1 From 416cc7d31b067d2016ad9de6ac3e11ddb590e858 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 19:36:07 +0000 Subject: [PATCH 27/90] Can't use dataclass kw_only on 3.9 --- pydantic_ai_slim/pydantic_ai/profiles/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 7b7f9b9d0..dfd7c3c45 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -11,7 +11,7 @@ from ._json_schema import JsonSchemaTransformer -@dataclass(kw_only=True) +@dataclass class ModelProfile: """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used.""" From 4b0e5cf8ea83364aa8d7f3a0d28cd808f3bc4238 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 20:29:27 +0000 Subject: [PATCH 28/90] Improve test coverage --- pydantic_ai_slim/pydantic_ai/_output.py | 13 +--- pydantic_ai_slim/pydantic_ai/_utils.py | 14 ++-- .../pydantic_ai/models/__init__.py | 4 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 4 +- tests/test_agent.py | 73 ++++++++++++++++++- 5 files changed, 82 insertions(+), 26 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index d71707391..a218e9f3d 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -162,7 +162,6 @@ class StructuredTextOutput(Generic[OutputDataT]): """ name: str | None description: str | None - strict: bool | None def __init__( self, @@ -170,13 +169,11 @@ def __init__( *, name: str | None = None, description: str | None = None, - strict: bool | None = True, instructions: bool | str | None = None, ): self.outputs = flatten_output_spec(type_) self.name = name self.description = description - self.strict = strict self.instructions = instructions @@ -271,7 +268,6 @@ def build( output_spec.outputs, name=output_spec.name, description=output_spec.description, - strict=output_spec.strict, ), instructions=output_spec.instructions, ) @@ -535,7 +531,7 @@ def instructions(self, template: str) -> str: template = self._instructions if '{schema}' not in template: - raise UserError("Structured output instructions template must contain a '{schema}' placeholder.") + template = '\n\n'.join([template, '{schema}']) object_def = self.object_def schema = object_def.json_schema.copy() @@ -597,7 +593,7 @@ def find_tool( @dataclass(init=False) -class ToolOrTextOutputSchema(PlainTextOutputSchema[OutputDataT], ToolOutputSchema[OutputDataT]): +class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchema[OutputDataT]): def __init__( self, processor: PlainTextOutputProcessor[OutputDataT] | None, @@ -610,11 +606,6 @@ def __init__( def mode(self) -> OutputMode: return 'tool_or_text' - def raise_if_unsupported(self, profile: ModelProfile) -> None: - """Raise an error if the mode is not supported by the model.""" - if not profile.supports_tools: - raise UserError('Output tools are not supported by the model.') - @dataclass class OutputObjectDefinition: diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 051ec08f7..f5b52a9c7 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -12,7 +12,7 @@ from datetime import datetime, timezone from functools import partial from types import GenericAlias -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload from anyio.to_thread import run_sync from pydantic import BaseModel, TypeAdapter @@ -341,11 +341,9 @@ def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, s # Recursively update refs in properties if 'properties' in s: - props: dict[str, Any] = s['properties'] + props: dict[str, dict[str, Any]] = s['properties'] for prop in props.values(): - if isinstance(prop, dict): - prop = cast(dict[str, Any], prop) - _update_mapped_json_schema_refs(prop, name_mapping) + _update_mapped_json_schema_refs(prop, name_mapping) # Handle arrays if 'items' in s and isinstance(s['items'], dict): @@ -354,16 +352,14 @@ def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, s if 'prefixItems' in s: prefix_items: list[dict[str, Any]] = s['prefixItems'] for item in prefix_items: - if isinstance(item, dict): - _update_mapped_json_schema_refs(item, name_mapping) + _update_mapped_json_schema_refs(item, name_mapping) # Handle unions for union_type in ['anyOf', 'oneOf']: if union_type in s: union_items: list[dict[str, Any]] = s[union_type] for item in union_items: - if isinstance(item, dict): - _update_mapped_json_schema_refs(item, name_mapping) + _update_mapped_json_schema_refs(item, name_mapping) def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]: diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index e55f85048..eea15ec0e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -724,8 +724,6 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinit def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition): - schema_transformer = transformer(o.json_schema, strict=o.strict) + schema_transformer = transformer(o.json_schema, strict=True) son_schema = schema_transformer.walk() - if o.strict is None: - o = replace(o, strict=schema_transformer.is_strict_compatible) return replace(o, json_schema=son_schema) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index eace317b9..b1c54f6fc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -284,7 +284,7 @@ async def _completions_create( if model_request_parameters.output_mode == 'structured_text': if output_object := model_request_parameters.output_object: response_format = self._map_json_schema(output_object) - elif self.profile.supports_json_object_response_format: + elif self.profile.supports_json_object_response_format: # pragma: no branch response_format = {'type': 'json_object'} sampling_settings = ( @@ -701,7 +701,7 @@ async def _responses_create( if model_request_parameters.output_mode == 'structured_text': if output_object := model_request_parameters.output_object: text = {'format': self._map_json_schema(output_object)} - elif self.profile.supports_json_object_response_format: + elif self.profile.supports_json_object_response_format: # pragma: no branch text = {'format': {'type': 'json_object'}} # Without this trick, we'd hit this error: diff --git a/tests/test_agent.py b/tests/test_agent.py index 71c34c0ea..f760b3a5c 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1262,6 +1262,28 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_default_structured_output_mode(): + def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='hello')]) + + tool_model = FunctionModel(hello, profile=ModelProfile(default_structured_output_mode='tool')) + structured_text_model = FunctionModel( + hello, + profile=ModelProfile( + supports_json_schema_response_format=True, default_structured_output_mode='structured_text' + ), + ) + + class Foo(BaseModel): + bar: str + + tool_agent = Agent(tool_model, output_type=Foo) + assert tool_agent._output_schema.mode == 'tool' # type: ignore[reportPrivateUsage] + + structured_text_agent = Agent(structured_text_model, output_type=Foo) + assert structured_text_agent._output_schema.mode == 'structured_text' # type: ignore[reportPrivateUsage] + + def test_output_type_structured_text(): def return_city_location(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: text = CityLocation(city='Mexico City', country='Mexico').model_dump_json() @@ -1311,6 +1333,45 @@ class CityLocation(BaseModel): ) +def test_output_type_structured_text_with_custom_instructions(): + def return_foo(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + text = Foo(bar='baz').model_dump_json() + return ModelResponse(parts=[TextPart(content=text)]) + + m = FunctionModel(return_foo) + + class Foo(BaseModel): + bar: str + + agent = Agent(m, output_type=StructuredTextOutput(Foo, instructions='Gimme some JSON:')) + + result = agent.run_sync('What is the capital of Mexico?') + assert result.output == snapshot(Foo(bar='baz')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Gimme some JSON: + +{"properties": {"bar": {"type": "string"}}, "required": ["bar"], "title": "Foo", "type": "object"}\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"bar":"baz"}')], + usage=Usage(requests=1, request_tokens=56, response_tokens=4, total_tokens=60), + model_name='function:return_foo:', + timestamp=IsDatetime(), + ), + ] + ) + + def test_output_type_structured_text_with_defs(): class Foo(BaseModel): """Foo description""" @@ -3048,7 +3109,17 @@ def test_unsupported_output_mode(): class Foo(BaseModel): bar: str - agent = Agent('test', output_type=StructuredTextOutput(Foo, instructions=False)) + def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart('hello')]) + + model = FunctionModel(hello, profile=ModelProfile(supports_tools=False, supports_json_schema_response_format=False)) + + agent = Agent(model, output_type=StructuredTextOutput(Foo, instructions=False)) with pytest.raises(UserError, match='Structured output without using instructions is not supported by the model.'): agent.run_sync('Hello') + + agent = Agent(model, output_type=ToolOutput(Foo)) + + with pytest.raises(UserError, match='Output tools are not supported by the model.'): + agent.run_sync('Hello') From 094920f7cc090e214024cf5230a0c97128aa7e0a Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 20:42:17 +0000 Subject: [PATCH 29/90] Improve test coverage --- pydantic_ai_slim/pydantic_ai/models/openai.py | 15 +++++---------- tests/test_agent.py | 4 ++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index b1c54f6fc..2391cfd59 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -427,18 +427,14 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: function={'name': t.tool_name, 'arguments': t.args_as_json_str()}, ) - @staticmethod - def _map_json_schema(o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: + def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage] 'type': 'json_schema', - 'json_schema': { - 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, - 'schema': o.json_schema, - }, + 'json_schema': {'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, 'schema': o.json_schema, 'strict': True}, } if o.description: response_format_param['json_schema']['description'] = o.description - if o.strict: # pragma: no branch + if OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: response_format_param['json_schema']['strict'] = o.strict return response_format_param @@ -827,8 +823,7 @@ def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam: type='function_call', ) - @staticmethod - def _map_json_schema(o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam: + def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam: response_format_param: responses.ResponseFormatTextJSONSchemaConfigParam = { 'type': 'json_schema', 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, @@ -836,7 +831,7 @@ def _map_json_schema(o: OutputObjectDefinition) -> responses.ResponseFormatTextJ } if o.description: response_format_param['description'] = o.description - if o.strict: # pragma: no branch + if OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: # pragma: no branch response_format_param['strict'] = o.strict return response_format_param diff --git a/tests/test_agent.py b/tests/test_agent.py index f760b3a5c..bcccb5517 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1264,7 +1264,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: def test_default_structured_output_mode(): def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart(content='hello')]) + return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover tool_model = FunctionModel(hello, profile=ModelProfile(default_structured_output_mode='tool')) structured_text_model = FunctionModel( @@ -3110,7 +3110,7 @@ class Foo(BaseModel): bar: str def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: - return ModelResponse(parts=[TextPart('hello')]) + return ModelResponse(parts=[TextPart('hello')]) # pragma: no cover model = FunctionModel(hello, profile=ModelProfile(supports_tools=False, supports_json_schema_response_format=False)) From 9f617069755ee24d123f0bebb50be082074761b2 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 20:52:08 +0000 Subject: [PATCH 30/90] Improve test coverage --- pydantic_ai_slim/pydantic_ai/models/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 2391cfd59..d16e3423d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -434,7 +434,7 @@ def _map_json_schema(self, o: OutputObjectDefinition) -> chat.completion_create_ } if o.description: response_format_param['json_schema']['description'] = o.description - if OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: + if OpenAIModelProfile.from_profile(self.profile).openai_supports_strict_tool_definition: # pragma: no branch response_format_param['json_schema']['strict'] = o.strict return response_format_param From 9f51387c4916a2b24f21afc52a2e1f5c1c7b9cac Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 21:11:48 +0000 Subject: [PATCH 31/90] Remove unnecessary coverage ignores --- pydantic_ai_slim/pydantic_ai/profiles/__init__.py | 2 +- pydantic_ai_slim/pydantic_ai/profiles/google.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index dfd7c3c45..8d87f973d 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -32,7 +32,7 @@ class ModelProfile: Don't include any text or Markdown fencing before or after. """ ) - """The instructions to use for prompted JSON output. The schema placeholder will be replaced with the JSON schema for the output.""" + """The instructions to use for prompted JSON output. The '{schema}' placeholder will be replaced with the JSON schema for the output.""" json_schema_transformer: type[JsonSchemaTransformer] | None = None """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index b151a5997..de85ea8d5 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -51,7 +51,7 @@ def transform(self, schema: JsonSchema) -> JsonSchema: schema.pop('title', None) schema.pop('default', None) schema.pop('$schema', None) - if (const := schema.pop('const', None)) is not None: # pragma: no cover + if (const := schema.pop('const', None)) is not None: # Gemini doesn't support const, but it does support enum with a single value schema['enum'] = [const] schema.pop('discriminator', None) From 9a1e6283eba2e90ca817bbe6edf97efbd8ebdb10 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 21:52:03 +0000 Subject: [PATCH 32/90] Remove unnecessary coverage ignore --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- pydantic_ai_slim/pydantic_ai/profiles/_json_schema.py | 2 +- tests/models/test_gemini.py | 2 +- tests/models/test_google.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index c4c28b37b..d4da97ff1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -221,7 +221,7 @@ async def _make_request( if model_request_parameters.output_mode == 'structured_text': if output_object := model_request_parameters.output_object: if tools: - raise UserError('Google does not support JSON schema output and tools at the same time.') + raise UserError('Gemini does not support JSON schema output and tools at the same time.') generation_config['response_mime_type'] = 'application/json' generation_config['response_schema'] = self._map_response_schema(output_object) diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 5eeedd703..32ce697cc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -255,7 +255,7 @@ async def _generate_content( if model_request_parameters.output_mode == 'structured_text': if output_object := model_request_parameters.output_object: if tools: - raise UserError('Google does not support JSON schema output and tools at the same time.') + raise UserError('Gemini does not support JSON schema output and tools at the same time.') response_mime_type = 'application/json' response_schema = self._map_response_schema(output_object) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/_json_schema.py b/pydantic_ai_slim/pydantic_ai/profiles/_json_schema.py index 3feda2995..cf13b41e9 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/_json_schema.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/_json_schema.py @@ -174,7 +174,7 @@ def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]: # they are both null, so just return one of them return [cases[0]] - return cases # pragma: no cover + return cases class InlineDefsJsonSchemaTransformer(JsonSchemaTransformer): diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a7212f730..00dd69229 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -1595,7 +1595,7 @@ class CityLocation(BaseModel): async def get_user_country() -> str: return 'Mexico' # pragma: no cover - with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): + with pytest.raises(UserError, match='Gemini does not support JSON schema output and tools at the same time.'): await agent.run('What is the largest city in the user country?') diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 503f7c678..d522b306c 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -894,7 +894,7 @@ class CityLocation(BaseModel): async def get_user_country() -> str: return 'Mexico' # pragma: no cover - with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): + with pytest.raises(UserError, match='Gemini does not support JSON schema output and tools at the same time.'): await agent.run('What is the largest city in the user country?') From 2b5fa819a81e103b726a4803fe1addaf11e67cc0 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 22:17:14 +0000 Subject: [PATCH 33/90] Add docs --- docs/output.md | 93 +++++++++++++++++++++++++++++++++++++++--- tests/test_examples.py | 3 ++ 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/docs/output.md b/docs/output.md index 7eac789c4..5d7793d44 100644 --- a/docs/output.md +++ b/docs/output.md @@ -4,7 +4,7 @@ The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or Both `AgentRunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. -A run ends when a plain text response is received (assuming no output type is specified or `str` is one of the allowed options), or when the model responds with one of the structured output types by calling a special output tool. A run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits). +A run ends when a plain text response is received (assuming no output type is specified or `str` is one of the allowed options), or when the model responds with one of the structured output types. A run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits). Here's an example using a Pydantic model as the `output_type`, forcing the model to respond with data matching our specification: @@ -31,12 +31,12 @@ _(This example is complete, it can be run "as is")_ ## Output data {#structured-output} -The [`Agent`][pydantic_ai.Agent] class constructor takes an `output_type` argument that takes one or more types or [output functions](#output-functions). It supports both type unions and lists of types and functions. +The [`Agent`][pydantic_ai.Agent] class constructor takes an `output_type` argument that takes one or more types or [output functions](#output-functions). It supports simple scalar types, list and dict types, dataclasses and Pydantic models, as well as type unions -- generally everything supported as type hints in a Pydantic model. Multiple types and functions can also be provided in a list. -When no output type is specified, or when the output type is `str` or a union or list of types including `str`, the model is allowed to respond with plain text, and this text is used as the output data. -If `str` is not among the allowed output types, the model is not allowed to respond with plain text and is forced to return structured data (or arguments to an output function). +By default, Pydantic AI leverages the model's tool calling capability to make it return structured data. When multiple output types are specified (in a union or list), each member is registered with the model as a separate output tool in order to reduce the complexity of the schema and maximise the chances a model will respond correctly. This has been shown to work well across a wide range of models. If you'd like to change the names of the output tools, use a model's native structured output feature, or pass the output schema to the model in its [instructions](agents.md#instructions), you can use an [output mode](#output-modes) marker class. -If the output type is a union or list with multiple members, each member (except for `str`, if it is a member) is registered with the model as a separate output tool in order to reduce the complexity of the tool schemas and maximise the chances a model will respond correctly. +When no output type is specified, or when `str` is among the output types, the model is allowed to respond with plain text, and this text is used as the output data. +If `str` is not among the output types, the model is forced to return structured data (or arguments to an output function). If the output type schema is not of type `"object"` (e.g. it's `int` or `list[int]`), the output type is wrapped in a single element object, so the schema of all tools registered with the model are object schemas. @@ -88,7 +88,7 @@ print(result.output) _(This example is complete, it can be run "as is")_ -Here's an example of using a union return type, for which PydanticAI will register multiple tools and wraps non-object schemas in an object: +Here's an example of using a union return type, which will register multiple output tools and wrap non-object schemas in an object: ```python {title="colors_or_sizes.py"} from typing import Union @@ -235,11 +235,92 @@ explanation = 'I am not equipped to provide travel information, such as flights """ ``` +### Output modes + +Pydantic AI implements three different methods to get a model to output structured data: + +1. **Tool output**, where the output JSON schema is provided to the model as the arguments schema of a special output tool. This is the default as it's supported by virtually all models and has been shown to work very well. +2. **Structured text output** using a model API's native **"forced JSON Schema"** feature (aka "JSON Schema response format" or Structured Output), where the model is forced to only output text matching the provided JSON schema. This is not supported by all models (most notably Claude) and sometimes comes with restrictions (for example, Gemini cannot use tools at the same time as JSON output). +3. **Structured text output** using [**instructions**](agents.md#instructions), where the model is prompted to output text matching the provided JSON schema and it's up to the model to interpret those instructions correctly. If the model API supports the "JSON Mode" (aka "JSON Object response format") feature to force the model to output valid JSON, this is enabled, but it's still up to the model to abide by the schema. This is supported by all models, but is the least reliable approach as the model is not forced to match the schema. Pydantic AI will validate the returned structured data and tell the model to try again if validation fails, but if the model is not intelligent enough this may not be sufficient. + +By default, Pydantic AI will use the tool output mode and register a separate output tool for each output type (or function). If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.result.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. + +If you'd like to use the structured text output mode, you can wrap the type(s) in the [`StructuredTextOutput`][pydantic_ai.result.StructuredTextOutput] marker class that also lets you specify a name and description if the name and docstring of the type or function are not sufficient. Additionally, it supports an `instructions` argument that is `None` by default, indicating that Pydantic AI should choose the best strategy supported by the model: forced JSON schema or instructions. You can set it to `False` to never use instructions (which will result in an error if the forced JSON schema feature is not supported by the model) or `True` to always use instructions. If you'd like to use a custom instructions template instead of the [default](pydantic_ai.profiles.ModelProfile.structured_output_instructions_template), you can pass a string with a `{schema}` placeholder that will be replaced with the actual JSON schema. + +Finally, if you provide an [output function](#output-functions) that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like to the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.result.TextOutput] marker class. If desired, this marker class can be used alongside one or more `ToolOutput` marker classes (or unmarked types or functions) in a list provided to `output_type`. + +Here's an example of these 3 output mode marker classes in action: + +```python +from pydantic import BaseModel + +from pydantic_ai import Agent +from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput + + +class Fruit(BaseModel): + name: str + color: str + + +class Vehicle(BaseModel): + name: str + wheels: int + + +class Device(BaseModel): + name: str + kind: str + + +agent = Agent( + 'openai:gpt-4o', + output_type=[ + ToolOutput(Fruit, name='return_fruit'), + ToolOutput(Vehicle, name='return_vehicle'), + ToolOutput(Device, name='return_device'), + ], +) +result = agent.run_sync('What is a banana?') +print(result.output) +#> name='banana' color='yellow' + +agent = Agent( + 'openai:gpt-4o', + output_type=StructuredTextOutput([Fruit, Vehicle, Device]), +) +result = agent.run_sync('What is a Ford Explorer?') +print(result.output) +#> name='Ford Explorer' wheels=4 + +agent = Agent( + 'openai:gpt-4o', + output_type=StructuredTextOutput([Fruit, Vehicle, Device], instructions=True), +) +result = agent.run_sync('What is a MacBook?') +print(result.output) +#> name='MacBook' kind='laptop' + + +def split_into_words(text: str) -> list[str]: + return text.split() + + +agent = Agent( + 'openai:gpt-4o', + output_type=TextOutput(split_into_words), +) +result = agent.run_sync('Who was Albert Einstein?') +print(result.output) +#> ['Albert', 'Einstein', 'was', 'a', 'German-born', 'theoretical', 'physicist.'] +``` + ### Output validators {#output-validator-functions} Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. If you want to implement separate validation logic for different output types, it's recommended to use [output functions](#output-functions) instead, to save you from having to do `isinstance` checks inside the output validator. +If you want the model to output plain text, do your own processing or validation, and then have the agent's final output be the result of your function, it's recommended to use an [output function](#output-functions) with the `TextOutput` [output mode](#output-modes) marker class. Here's a simplified variant of the [SQL Generation example](examples/sql-gen.md): diff --git a/tests/test_examples.py b/tests/test_examples.py index 98edf8a9e..204bbe931 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -432,6 +432,9 @@ async def list_tools() -> list[None]: 'explanation': 'I am not equipped to provide travel information, such as flights from Amsterdam to Mexico City.' }, ), + 'What is a banana?': ToolCallPart(tool_name='return_fruit', args={'name': 'banana', 'color': 'yellow'}), + 'What is a Ford Explorer?': '{"result": {"kind": "Vehicle", "data": {"name": "Ford Explorer", "wheels": 4}}}', + 'What is a MacBook?': '{"result": {"kind": "Device", "data": {"name": "MacBook", "kind": "laptop"}}}', } tool_responses: dict[tuple[str, str], str] = { From 6c4662b309b40d92a7b3d24603e6a4a45480ec26 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 22:25:36 +0000 Subject: [PATCH 34/90] Fix docs refs --- docs/api/result.md | 3 +++ docs/output.md | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/api/result.md b/docs/api/result.md index 29d08e9cf..8ed64e965 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -6,3 +6,6 @@ members: - OutputDataT - StreamedRunResult + - ToolOutput + - StructuredTextOutput + - TextOutput diff --git a/docs/output.md b/docs/output.md index 5d7793d44..878ff24bb 100644 --- a/docs/output.md +++ b/docs/output.md @@ -245,7 +245,7 @@ Pydantic AI implements three different methods to get a model to output structur By default, Pydantic AI will use the tool output mode and register a separate output tool for each output type (or function). If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.result.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. -If you'd like to use the structured text output mode, you can wrap the type(s) in the [`StructuredTextOutput`][pydantic_ai.result.StructuredTextOutput] marker class that also lets you specify a name and description if the name and docstring of the type or function are not sufficient. Additionally, it supports an `instructions` argument that is `None` by default, indicating that Pydantic AI should choose the best strategy supported by the model: forced JSON schema or instructions. You can set it to `False` to never use instructions (which will result in an error if the forced JSON schema feature is not supported by the model) or `True` to always use instructions. If you'd like to use a custom instructions template instead of the [default](pydantic_ai.profiles.ModelProfile.structured_output_instructions_template), you can pass a string with a `{schema}` placeholder that will be replaced with the actual JSON schema. +If you'd like to use the structured text output mode, you can wrap the type(s) in the [`StructuredTextOutput`][pydantic_ai.result.StructuredTextOutput] marker class that also lets you specify a name and description if the name and docstring of the type or function are not sufficient. Additionally, it supports an `instructions` argument that is `None` by default, indicating that Pydantic AI should choose the best strategy supported by the model: forced JSON schema or instructions. You can set it to `False` to never use instructions (which will result in an error if the forced JSON schema feature is not supported by the model) or `True` to always use instructions. If you'd like to use a custom instructions template instead of the [default][pydantic_ai.profiles.ModelProfile.structured_output_instructions_template], you can pass a string with a `{schema}` placeholder that will be replaced with the actual JSON schema. Finally, if you provide an [output function](#output-functions) that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like to the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.result.TextOutput] marker class. If desired, this marker class can be used alongside one or more `ToolOutput` marker classes (or unmarked types or functions) in a list provided to `output_type`. From 3ed3431440cc717c8b5f41b68462870610c22f24 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 13 Jun 2025 22:28:17 +0000 Subject: [PATCH 35/90] Fix nested list in docs --- docs/output.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/output.md b/docs/output.md index 878ff24bb..8988f511e 100644 --- a/docs/output.md +++ b/docs/output.md @@ -48,6 +48,7 @@ Structured outputs (like tools) use Pydantic to build the JSON schema used for t Static type checkers like pyright and mypy will do their best the infer the agent's output type from the `output_type` you've specified, but they're not always able to do so correctly when you provide functions or multiple types in a union or list, even though PydanticAI will behave correctly. When this happens, your type checker will complain even when you're confident you've passed a valid `output_type`, and you'll need to help the type checker by explicitly specifying the generic parameters on the `Agent` constructor. This is shown in the second example below and the output functions example further down. Specifically, there are three valid uses of `output_type` where you'll need to do this: + 1. When using a union of types, e.g. `output_type=Foo | Bar` or in older Python, `output_type=Union[Foo, Bar]`. Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands in Python 3.15, type checkers do not consider these a valid value for `output_type`. In addition to the generic parameters on the `Agent` constructor, you'll need to add `# type: ignore` to the line that passes the union to `output_type`. 2. With mypy: When using a list, as a functionally equivalent alternative to a union, or because you're passing in [output functions](#output-functions). Pyright does handle this correctly, and we've filed [an issue](https://github.com/python/mypy/issues/19142) with mypy to try and get this fixed. 3. With mypy: when using an async output function. Pyright does handle this correctly, and we've filed [an issue](https://github.com/python/mypy/issues/19143) with mypy to try and get this fixed. From a86d7d4469cc63e09b9ecc232369fb24b05fe249 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 21:33:23 +0000 Subject: [PATCH 36/90] Split StructuredTextOutput into ModelStructuredOutput and PromptedStructuredOutput, move to public output module --- docs/api/result.md | 3 +- docs/output.md | 35 ++- pydantic_ai_slim/pydantic_ai/__init__.py | 8 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 10 +- pydantic_ai_slim/pydantic_ai/_cli.py | 2 +- .../pydantic_ai/_function_schema.py | 4 +- pydantic_ai_slim/pydantic_ai/_output.py | 249 ++++++------------ pydantic_ai_slim/pydantic_ai/agent.py | 9 +- .../pydantic_ai/models/__init__.py | 3 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 20 +- pydantic_ai_slim/pydantic_ai/models/google.py | 20 +- pydantic_ai_slim/pydantic_ai/models/openai.py | 40 +-- pydantic_ai_slim/pydantic_ai/output.py | 167 ++++++++++++ .../pydantic_ai/profiles/__init__.py | 15 +- .../pydantic_ai/profiles/google.py | 4 +- .../pydantic_ai/profiles/openai.py | 8 +- pydantic_ai_slim/pydantic_ai/result.py | 10 +- ...anthropic_prompted_structured_output.yaml} | 0 ..._prompted_structured_output_multiple.yaml} | 0 ... test_gemini_model_structured_output.yaml} | 0 ...ini_model_structured_output_multiple.yaml} | 0 ...st_gemini_prompted_structured_output.yaml} | 0 ..._prompted_structured_output_multiple.yaml} | 0 ...rompted_structured_output_with_tools.yaml} | 0 ... test_google_model_structured_output.yaml} | 0 ...gle_model_structured_output_multiple.yaml} | 0 ...st_google_prompted_structured_output.yaml} | 0 ..._prompted_structured_output_multiple.yaml} | 0 ...rompted_structured_output_with_tools.yaml} | 0 ... test_openai_model_structured_output.yaml} | 0 ...nai_model_structured_output_multiple.yaml} | 0 ...st_openai_prompted_structured_output.yaml} | 0 ..._prompted_structured_output_multiple.yaml} | 0 ...yaml => test_model_structured_output.yaml} | 0 ...est_model_structured_output_multiple.yaml} | 0 ...l => test_prompted_structured_output.yaml} | 0 ..._prompted_structured_output_multiple.yaml} | 0 tests/models/test_anthropic.py | 26 +- tests/models/test_gemini.py | 33 ++- tests/models/test_google.py | 33 ++- tests/models/test_openai.py | 25 +- tests/models/test_openai_responses.py | 18 +- tests/test_agent.py | 45 ++-- tests/test_streaming.py | 6 +- tests/test_tools.py | 3 +- tests/typed_agent.py | 2 +- 46 files changed, 444 insertions(+), 354 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/output.py rename tests/models/cassettes/test_anthropic/{test_anthropic_structured_text_output.yaml => test_anthropic_prompted_structured_output.yaml} (100%) rename tests/models/cassettes/test_anthropic/{test_anthropic_structured_text_output_multiple.yaml => test_anthropic_prompted_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_gemini/{test_gemini_structured_text_output.yaml => test_gemini_model_structured_output.yaml} (100%) rename tests/models/cassettes/test_gemini/{test_gemini_structured_text_output_multiple.yaml => test_gemini_model_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_gemini/{test_gemini_structured_text_output_with_instructions.yaml => test_gemini_prompted_structured_output.yaml} (100%) rename tests/models/cassettes/test_gemini/{test_gemini_structured_text_output_with_instructions_multiple.yaml => test_gemini_prompted_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_gemini/{test_gemini_structured_text_output_with_instructions_with_tools.yaml => test_gemini_prompted_structured_output_with_tools.yaml} (100%) rename tests/models/cassettes/test_google/{test_google_structured_text_output.yaml => test_google_model_structured_output.yaml} (100%) rename tests/models/cassettes/test_google/{test_google_structured_text_output_multiple.yaml => test_google_model_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_google/{test_google_structured_text_output_with_instructions.yaml => test_google_prompted_structured_output.yaml} (100%) rename tests/models/cassettes/test_google/{test_google_structured_text_output_with_instructions_multiple.yaml => test_google_prompted_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_google/{test_google_structured_text_output_with_instructions_with_tools.yaml => test_google_prompted_structured_output_with_tools.yaml} (100%) rename tests/models/cassettes/test_openai/{test_openai_structured_text_output.yaml => test_openai_model_structured_output.yaml} (100%) rename tests/models/cassettes/test_openai/{test_openai_structured_text_output_multiple.yaml => test_openai_model_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_openai/{test_openai_structured_text_output_with_instructions.yaml => test_openai_prompted_structured_output.yaml} (100%) rename tests/models/cassettes/test_openai/{test_openai_structured_text_output_with_instructions_multiple.yaml => test_openai_prompted_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_openai_responses/{test_structured_text_output.yaml => test_model_structured_output.yaml} (100%) rename tests/models/cassettes/test_openai_responses/{test_structured_text_output_multiple.yaml => test_model_structured_output_multiple.yaml} (100%) rename tests/models/cassettes/test_openai_responses/{test_structured_text_output_with_instructions.yaml => test_prompted_structured_output.yaml} (100%) rename tests/models/cassettes/test_openai_responses/{test_structured_text_output_with_instructions_multiple.yaml => test_prompted_structured_output_multiple.yaml} (100%) diff --git a/docs/api/result.md b/docs/api/result.md index 8ed64e965..85593a221 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -7,5 +7,6 @@ - OutputDataT - StreamedRunResult - ToolOutput - - StructuredTextOutput + - ModelStructuredOutput + - PromptedStructuredOutput - TextOutput diff --git a/docs/output.md b/docs/output.md index 8988f511e..0151064bf 100644 --- a/docs/output.md +++ b/docs/output.md @@ -36,7 +36,7 @@ The [`Agent`][pydantic_ai.Agent] class constructor takes an `output_type` argume By default, Pydantic AI leverages the model's tool calling capability to make it return structured data. When multiple output types are specified (in a union or list), each member is registered with the model as a separate output tool in order to reduce the complexity of the schema and maximise the chances a model will respond correctly. This has been shown to work well across a wide range of models. If you'd like to change the names of the output tools, use a model's native structured output feature, or pass the output schema to the model in its [instructions](agents.md#instructions), you can use an [output mode](#output-modes) marker class. When no output type is specified, or when `str` is among the output types, the model is allowed to respond with plain text, and this text is used as the output data. -If `str` is not among the output types, the model is forced to return structured data (or arguments to an output function). +If `str` is not among the output types, the model is forced to return structured data or call an output function. If the output type schema is not of type `"object"` (e.g. it's `int` or `list[int]`), the output type is wrapped in a single element object, so the schema of all tools registered with the model are object schemas. @@ -133,7 +133,7 @@ from typing import Union from pydantic import BaseModel from pydantic_ai import Agent, ModelRetry, RunContext -from pydantic_ai._output import ToolRetryError +from pydantic_ai.output import ToolRetryError from pydantic_ai.exceptions import UnexpectedModelBehavior @@ -240,23 +240,22 @@ explanation = 'I am not equipped to provide travel information, such as flights Pydantic AI implements three different methods to get a model to output structured data: -1. **Tool output**, where the output JSON schema is provided to the model as the arguments schema of a special output tool. This is the default as it's supported by virtually all models and has been shown to work very well. -2. **Structured text output** using a model API's native **"forced JSON Schema"** feature (aka "JSON Schema response format" or Structured Output), where the model is forced to only output text matching the provided JSON schema. This is not supported by all models (most notably Claude) and sometimes comes with restrictions (for example, Gemini cannot use tools at the same time as JSON output). -3. **Structured text output** using [**instructions**](agents.md#instructions), where the model is prompted to output text matching the provided JSON schema and it's up to the model to interpret those instructions correctly. If the model API supports the "JSON Mode" (aka "JSON Object response format") feature to force the model to output valid JSON, this is enabled, but it's still up to the model to abide by the schema. This is supported by all models, but is the least reliable approach as the model is not forced to match the schema. Pydantic AI will validate the returned structured data and tell the model to try again if validation fails, but if the model is not intelligent enough this may not be sufficient. +1. **Tool output**, where the output JSON schema is provided to the model as the parameters schema of a special output tool. This is the default as it's supported by virtually all models and has been shown to work very well. +2. **Structured text output** using a model API's native **"forced JSON Schema"** feature (aka "JSON Schema response format" or "Structured Outputs"), where the model is forced to only output text matching the provided JSON schema. This is currently only supported by OpenAI and Gemini and sometimes comes with restrictions (for example, Gemini cannot use tools at the same time as JSON output). +3. **Structured text output** using [**instructions**](agents.md#instructions), where the model is prompted to output text matching the provided JSON schema and it's up to the model to interpret those instructions correctly. If the model API supports the "JSON Mode" (aka "JSON Object response format") feature to force the model to output valid JSON, this is enabled, but it's still up to the model to abide by the schema. This is usable with all models, but is the least reliable approach as the model is not forced to match the schema. Pydantic AI will validate the returned structured data and tell the model to try again if validation fails, but if the model is not intelligent enough this may not be sufficient. -By default, Pydantic AI will use the tool output mode and register a separate output tool for each output type (or function). If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.result.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. +By default, Pydantic AI will use the tool output mode and register a separate output tool for each output type (or function). If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.output.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. -If you'd like to use the structured text output mode, you can wrap the type(s) in the [`StructuredTextOutput`][pydantic_ai.result.StructuredTextOutput] marker class that also lets you specify a name and description if the name and docstring of the type or function are not sufficient. Additionally, it supports an `instructions` argument that is `None` by default, indicating that Pydantic AI should choose the best strategy supported by the model: forced JSON schema or instructions. You can set it to `False` to never use instructions (which will result in an error if the forced JSON schema feature is not supported by the model) or `True` to always use instructions. If you'd like to use a custom instructions template instead of the [default][pydantic_ai.profiles.ModelProfile.structured_output_instructions_template], you can pass a string with a `{schema}` placeholder that will be replaced with the actual JSON schema. +If you'd like to use the structured text output mode, you can wrap the type(s) in the [`StructuredTextOutput`][pydantic_ai.output.StructuredTextOutput] marker class that also lets you specify a name and description if the name and docstring of the type or function are not sufficient. Additionally, it supports an `instructions` argument that is `None` by default, indicating that Pydantic AI should choose the best strategy supported by the model: forced JSON schema or instructions. You can set it to `False` to never use instructions (which will result in an error if the forced JSON schema feature is not supported by the model) or `True` to always use instructions. If you'd like to use a custom instructions template instead of the [default][pydantic_ai.profiles.ModelProfile.prompted_structured_output_template], you can pass a string with a `{schema}` placeholder that will be replaced with the actual JSON schema. -Finally, if you provide an [output function](#output-functions) that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like to the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.result.TextOutput] marker class. If desired, this marker class can be used alongside one or more `ToolOutput` marker classes (or unmarked types or functions) in a list provided to `output_type`. +Finally, if you provide an [output function](#output-functions) that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like to the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. If desired, this marker class can be used alongside one or more `ToolOutput` marker classes (or unmarked types or functions) in a list provided to `output_type`. Here's an example of these 3 output mode marker classes in action: ```python from pydantic import BaseModel -from pydantic_ai import Agent -from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput +from pydantic_ai import Agent, ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput class Fruit(BaseModel): @@ -283,24 +282,24 @@ agent = Agent( ], ) result = agent.run_sync('What is a banana?') -print(result.output) -#> name='banana' color='yellow' +print(repr(result.output)) +#> Fruit(name='banana', color='yellow') agent = Agent( 'openai:gpt-4o', - output_type=StructuredTextOutput([Fruit, Vehicle, Device]), + output_type=ModelStructuredOutput([Fruit, Vehicle, Device]), ) result = agent.run_sync('What is a Ford Explorer?') -print(result.output) -#> name='Ford Explorer' wheels=4 +print(repr(result.output)) +#> Vehicle(name='Ford Explorer', wheels=4) agent = Agent( 'openai:gpt-4o', - output_type=StructuredTextOutput([Fruit, Vehicle, Device], instructions=True), + output_type=PromptedStructuredOutput([Fruit, Vehicle, Device]), ) result = agent.run_sync('What is a MacBook?') -print(result.output) -#> name='MacBook' kind='laptop' +print(repr(result.output)) +#> Device(name='MacBook', kind='laptop') def split_into_words(text: str) -> list[str]: diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 5e3c9aaa6..e5c0a3414 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .result import StructuredTextOutput, ToolOutput +from .output import ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -41,9 +41,11 @@ # tools 'Tool', 'RunContext', - # result + # output 'ToolOutput', - 'StructuredTextOutput', + 'ModelStructuredOutput', + 'PromptedStructuredOutput', + 'TextOutput', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 8b818b300..b588f35d5 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -17,7 +17,7 @@ from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage -from .result import OutputDataT +from .output import OutputDataT from .settings import ModelSettings, merge_model_settings from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc @@ -271,17 +271,15 @@ async def add_mcp_server_tools(server: MCPServer) -> None: function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] output_schema = ctx.deps.output_schema - model_profile = ctx.deps.model.profile output_tools = [] output_object = None if isinstance(output_schema, _output.ToolOutputSchema): output_tools = output_schema.tool_defs() - elif isinstance(output_schema, _output.StructuredTextOutputSchema): - if not output_schema.use_instructions(model_profile): - output_object = output_schema.object_def + elif isinstance(output_schema, _output.ModelStructuredOutputSchema): + output_object = output_schema.object_def - # Both ToolOrTextOutputSchema and StructuredTextOutputSchema inherit from TextOutputSchema + # ToolOrTextOutputSchema, ModelStructuredOutputSchema, and PromptedStructuredOutputSchema all inherit from TextOutputSchema allow_text_output = isinstance(output_schema, _output.TextOutputSchema) return models.ModelRequestParameters( diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index 7e041696b..feb007327 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -14,7 +14,7 @@ from typing_inspection.introspection import get_literal_values -from pydantic_ai.result import OutputDataT +from pydantic_ai.output import OutputDataT from pydantic_ai.tools import AgentDepsT from . import __version__ diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 6065681eb..f5db7df07 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -19,13 +19,11 @@ from pydantic_core import SchemaValidator, core_schema from typing_extensions import get_origin -from pydantic_ai.tools import RunContext - from ._griffe import doc_descriptions from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor if TYPE_CHECKING: - from .tools import DocstringFormat, ObjectJsonSchema + from .tools import DocstringFormat, ObjectJsonSchema, RunContext __all__ = ('function_schema',) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index a218e9f3d..f0a7db7fe 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -9,12 +9,24 @@ from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator -from typing_extensions import TypeAliasType, TypedDict, TypeVar, assert_never, get_args, get_origin -from typing_inspection import typing_objects -from typing_inspection.introspection import is_union_origin +from typing_extensions import TypedDict, TypeVar, assert_never from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry, UserError +from .output import ( + ModelStructuredOutput, + OutputDataT, + OutputMode, + OutputSpec, + OutputTypeOrFunction, + PromptedStructuredOutput, + StructuredOutputMode, + TextOutput, + TextOutputFunction, + ToolOutput, + ToolRetryError, + _flatten_output_spec, # pyright: ignore[reportPrivateUsage] +) from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition if TYPE_CHECKING: @@ -34,8 +46,6 @@ At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would resolve these potential variance issues. """ -OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) -"""Covariant type variable for the result data type of a run.""" OutputValidatorFunc = Union[ Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv], @@ -106,109 +116,6 @@ async def validate( return result_data -class ToolRetryError(Exception): - """Internal exception used to signal a `ToolRetry` message should be returned to the LLM.""" - - def __init__(self, tool_retry: _messages.RetryPromptPart): - self.tool_retry = tool_retry - super().__init__() - - -@dataclass(init=False) -class ToolOutput(Generic[OutputDataT]): - """Marker class to use tools for outputs, and customize the tool.""" - - output: OutputTypeOrFunction[OutputDataT] - name: str | None - description: str | None - max_retries: int | None - strict: bool | None - - def __init__( - self, - type_: OutputTypeOrFunction[OutputDataT], - *, - name: str | None = None, - description: str | None = None, - max_retries: int | None = None, - strict: bool | None = None, - ): - self.output = type_ - self.name = name - self.description = description - self.max_retries = max_retries - self.strict = strict - - -@dataclass -class TextOutput(Generic[OutputDataT]): - """Marker class to use text output with an output function.""" - - output_function: TextOutputFunction[OutputDataT] - - -@dataclass(init=False) -class StructuredTextOutput(Generic[OutputDataT]): - """Marker class to use structured text output for outputs.""" - - outputs: Sequence[OutputTypeOrFunction[OutputDataT]] - instructions: bool | str | None - """Whether to use the model's built-in functionality for structured output matching a JSON schema, or to pass the JSON schema to the model as instructions. - - If `None`, we'll use the model's built-in functionality if it's supported, and otherwise pass the JSON schema to the model as instructions. - If `True`, we'll pass the JSON schema to the model using the instructions template specified on the model's profile. - If `False`, we'll use the model's built-in functionality and raise an error if it's not supported. - If `str`, we'll pass the JSON schema to the model using the specified instructions template. - """ - name: str | None - description: str | None - - def __init__( - self, - type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], - *, - name: str | None = None, - description: str | None = None, - instructions: bool | str | None = None, - ): - self.outputs = flatten_output_spec(type_) - self.name = name - self.description = description - self.instructions = instructions - - -T_co = TypeVar('T_co', covariant=True) - -OutputTypeOrFunction = TypeAliasType( - 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) -) -OutputSpec = TypeAliasType( - 'OutputSpec', - Union[ - OutputTypeOrFunction[T_co], - ToolOutput[T_co], - TextOutput[T_co], - StructuredTextOutput[T_co], - Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], - ], - type_params=(T_co,), -) - -TextOutputFunction = TypeAliasType( - 'TextOutputFunction', - Union[ - Callable[[RunContext, str], Union[Awaitable[T_co], T_co]], - Callable[[str], Union[Awaitable[T_co], T_co]], - ], - type_params=(T_co,), -) - -OutputMode = Literal['text', 'tool', 'structured_text', 'tool_or_text'] -"""All output modes.""" -StructuredOutputMode = Literal['tool', 'structured_text'] -"""Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode""" - - class BaseOutputSchema(ABC, Generic[OutputDataT]): @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: @@ -262,20 +169,28 @@ def build( if output_spec is str: return PlainTextOutputSchema() - if isinstance(output_spec, StructuredTextOutput): - return StructuredTextOutputSchema( + if isinstance(output_spec, ModelStructuredOutput): + return ModelStructuredOutputSchema( + cls._build_processor( + output_spec.outputs, + name=output_spec.name, + description=output_spec.description, + ) + ) + elif isinstance(output_spec, PromptedStructuredOutput): + return PromptedStructuredOutputSchema( cls._build_processor( output_spec.outputs, name=output_spec.name, description=output_spec.description, ), - instructions=output_spec.instructions, + template=output_spec.template, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] - for output in flatten_output_spec(output_spec): + for output in _flatten_output_spec(output_spec): if output is str: text_outputs.append(cast(type[str], output)) elif isinstance(output, TextOutput): @@ -368,7 +283,7 @@ def _build_processor( description: str | None = None, strict: bool | None = None, ) -> ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]: - outputs = flatten_output_spec(outputs) + outputs = _flatten_output_spec(outputs) if len(outputs) == 1: return ObjectOutputProcessor(output=outputs[0], name=name, description=description, strict=strict) @@ -402,8 +317,10 @@ def __init__( self._tools = tools def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: - if mode == 'structured_text': - return StructuredTextOutputSchema(self.processor) + if mode == 'model_structured': + return ModelStructuredOutputSchema(self.processor) + elif mode == 'prompted_structured': + return PromptedStructuredOutputSchema(self.processor) elif mode == 'tool': return ToolOutputSchema(self.tools) else: @@ -467,23 +384,26 @@ async def process( ) -@dataclass(init=False) -class StructuredTextOutputSchema(TextOutputSchema[OutputDataT]): +@dataclass +class StructuredTextOutputSchema(TextOutputSchema[OutputDataT], ABC): processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] - _instructions: bool | str | None = None - - def __init__( - self, - processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], - instructions: bool | str | None = None, - ): - self.processor = processor - self._instructions = instructions @property def object_def(self) -> OutputObjectDefinition: return self.processor.object_def + +@dataclass +class ModelStructuredOutputSchema(StructuredTextOutputSchema[OutputDataT]): + @property + def mode(self) -> OutputMode: + return 'model_structured' + + def raise_if_unsupported(self, profile: ModelProfile) -> None: + """Raise an error if the mode is not supported by the model.""" + if not profile.supports_structured_output: + raise UserError('Structured output is not supported by the model.') + async def process( self, text: str, @@ -502,33 +422,26 @@ async def process( Returns: Either the validated output data (left) or a retry message (right). """ - text = _utils.strip_markdown_fences(text) - return await self.processor.process( text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) + +@dataclass +class PromptedStructuredOutputSchema(StructuredTextOutputSchema[OutputDataT]): + template: str | None = None + @property def mode(self) -> OutputMode: - return 'structured_text' + return 'prompted_structured' def raise_if_unsupported(self, profile: ModelProfile) -> None: """Raise an error if the mode is not supported by the model.""" - if self._instructions is False and not profile.supports_json_schema_response_format: - raise UserError('Structured output without using instructions is not supported by the model.') - - def use_instructions(self, profile: ModelProfile) -> bool: - if isinstance(self._instructions, bool): - return self._instructions - elif isinstance(self._instructions, str): - return True - else: - return not profile.supports_json_schema_response_format + pass - def instructions(self, template: str) -> str: + def instructions(self, default_template: str) -> str: """Get instructions to tell model to output JSON matching the schema.""" - if isinstance(self._instructions, str): - template = self._instructions + template = self.template or default_template if '{schema}' not in template: template = '\n\n'.join([template, '{schema}']) @@ -542,6 +455,30 @@ def instructions(self, template: str) -> str: return template.format(schema=json.dumps(schema)) + async def process( + self, + text: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + """Validate an output message. + + Args: + text: The output text to validate. + run_context: The current run context. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + text = _utils.strip_markdown_fences(text) + + return await self.processor.process( + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + @dataclass(init=False) class ToolOutputSchema(OutputSchema[OutputDataT]): @@ -972,31 +909,3 @@ async def process( raise # pragma: lax no cover else: return output - - -def get_union_args(tp: Any) -> tuple[Any, ...]: - """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple.""" - if typing_objects.is_typealiastype(tp): - tp = tp.__value__ - - origin = get_origin(tp) - if is_union_origin(origin): - return get_args(tp) - else: - return () - - -def flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: - outputs: Sequence[T] - if isinstance(output_spec, Sequence): - outputs = output_spec - else: - outputs = (output_spec,) - - outputs_flat: list[T] = [] - for output in outputs: - if union_types := get_union_args(output): - outputs_flat.extend(union_types) - else: - outputs_flat.append(output) - return outputs_flat diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 1738cb557..4d58f83f4 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -31,7 +31,8 @@ ) from ._agent_graph import HistoryProcessor from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model -from .result import FinalResult, OutputDataT, StreamedRunResult +from .output import OutputDataT +from .result import FinalResult, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -687,10 +688,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: ] model_profile = model_used.profile - if isinstance(output_schema, _output.StructuredTextOutputSchema) and output_schema.use_instructions( - model_profile - ): - instructions = output_schema.instructions(model_profile.structured_output_instructions_template) + if isinstance(output_schema, _output.PromptedStructuredOutputSchema): + instructions = output_schema.instructions(model_profile.prompted_structured_output_template) parts.append(instructions) parts = [p for p in parts if p] diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index eea15ec0e..3d76ae332 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -20,10 +20,11 @@ from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec -from .._output import OutputMode, OutputObjectDefinition +from .._output import OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl +from ..output import OutputMode from ..profiles._json_schema import JsonSchemaTransformer from ..settings import ModelSettings from ..tools import ToolDefinition diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4af7247d9..91c87cf6d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -218,15 +218,17 @@ async def _make_request( request_data['toolConfig'] = tool_config generation_config = _settings_to_generation_config(model_settings) - if model_request_parameters.output_mode == 'structured_text': - if output_object := model_request_parameters.output_object: - if tools: - raise UserError('Gemini does not support JSON schema output and tools at the same time.') - - generation_config['response_mime_type'] = 'application/json' - generation_config['response_schema'] = self._map_response_schema(output_object) - elif not tools: - generation_config['response_mime_type'] = 'application/json' + if model_request_parameters.output_mode == 'model_structured': + if tools: + raise UserError('Gemini does not support structured output and tools at the same time.') + + generation_config['response_mime_type'] = 'application/json' + + output_object = model_request_parameters.output_object + assert output_object is not None + generation_config['response_schema'] = self._map_response_schema(output_object) + elif model_request_parameters.output_mode == 'prompted_structured' and not tools: + generation_config['response_mime_type'] = 'application/json' if generation_config: request_data['generationConfig'] = generation_config diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index ec44f8480..6d6aee54f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -252,15 +252,17 @@ async def _generate_content( response_mime_type = None response_schema = None - if model_request_parameters.output_mode == 'structured_text': - if output_object := model_request_parameters.output_object: - if tools: - raise UserError('Gemini does not support JSON schema output and tools at the same time.') - - response_mime_type = 'application/json' - response_schema = self._map_response_schema(output_object) - elif not tools: - response_mime_type = 'application/json' + if model_request_parameters.output_mode == 'model_structured': + if tools: + raise UserError('Gemini does not support structured output and tools at the same time.') + + response_mime_type = 'application/json' + + output_object = model_request_parameters.output_object + assert output_object is not None + response_schema = self._map_response_schema(output_object) + elif model_request_parameters.output_mode == 'prompted_structured' and not tools: + response_mime_type = 'application/json' tool_config = self._get_tool_config(model_request_parameters, tools) system_instruction, contents = await self._map_messages(messages) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index d16e3423d..caff8a07a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -281,11 +281,14 @@ async def _completions_create( openai_messages = await self._map_messages(messages) response_format: chat.completion_create_params.ResponseFormat | None = None - if model_request_parameters.output_mode == 'structured_text': - if output_object := model_request_parameters.output_object: - response_format = self._map_json_schema(output_object) - elif self.profile.supports_json_object_response_format: # pragma: no branch - response_format = {'type': 'json_object'} + if model_request_parameters.output_mode == 'model_structured': + output_object = model_request_parameters.output_object + assert output_object is not None + response_format = self._map_json_schema(output_object) + elif ( + model_request_parameters.output_mode == 'prompted_structured' and self.profile.supports_json_output + ): # pragma: no branch + response_format = {'type': 'json_object'} sampling_settings = ( model_settings @@ -694,18 +697,21 @@ async def _responses_create( reasoning = self._get_reasoning(model_settings) text: responses.ResponseTextConfigParam | None = None - if model_request_parameters.output_mode == 'structured_text': - if output_object := model_request_parameters.output_object: - text = {'format': self._map_json_schema(output_object)} - elif self.profile.supports_json_object_response_format: # pragma: no branch - text = {'format': {'type': 'json_object'}} - - # Without this trick, we'd hit this error: - # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. - # Apparently they're only checking input messages for "JSON", not instructions. - assert isinstance(instructions, str) - openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) - instructions = NOT_GIVEN + if model_request_parameters.output_mode == 'model_structured': + output_object = model_request_parameters.output_object + assert output_object is not None + text = {'format': self._map_json_schema(output_object)} + elif ( + model_request_parameters.output_mode == 'prompted_structured' and self.profile.supports_json_output + ): # pragma: no branch + text = {'format': {'type': 'json_object'}} + + # Without this trick, we'd hit this error: + # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. + # Apparently they're only checking input messages for "JSON", not instructions. + assert isinstance(instructions, str) + openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) + instructions = NOT_GIVEN sampling_settings = ( model_settings diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py new file mode 100644 index 000000000..73ad0d1d0 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass +from typing import Any, Callable, Generic, Literal, Union + +from typing_extensions import TypeAliasType, TypeVar, get_args, get_origin +from typing_inspection import typing_objects +from typing_inspection.introspection import is_union_origin + +from .messages import RetryPromptPart +from .tools import RunContext + +OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) +"""Covariant type variable for the result data type of a run.""" + +T = TypeVar('T') + +T_co = TypeVar('T_co', covariant=True) + +OutputTypeOrFunction = TypeAliasType( + 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) +) + +OutputMode = Literal['text', 'tool', 'model_structured', 'prompted_structured', 'tool_or_text'] +"""All output modes.""" +StructuredOutputMode = Literal['tool', 'model_structured', 'prompted_structured'] +"""Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode""" + + +class ToolRetryError(Exception): + """Internal exception used to signal a `ToolRetry` message should be returned to the LLM.""" + + def __init__(self, tool_retry: RetryPromptPart): + self.tool_retry = tool_retry + super().__init__() + + +@dataclass(init=False) +class ToolOutput(Generic[OutputDataT]): + """Marker class to use a tool for output and optionally customize the tool.""" + + output: OutputTypeOrFunction[OutputDataT] + name: str | None + description: str | None + max_retries: int | None + strict: bool | None + + def __init__( + self, + type_: OutputTypeOrFunction[OutputDataT], + *, + name: str | None = None, + description: str | None = None, + max_retries: int | None = None, + strict: bool | None = None, + ): + self.output = type_ + self.name = name + self.description = description + self.max_retries = max_retries + self.strict = strict + + +@dataclass(init=False) +class ModelStructuredOutput(Generic[OutputDataT]): + """Marker class to use the model's built-in structured outputs functionality for outputs and optionally customize the name and description.""" + + outputs: Sequence[OutputTypeOrFunction[OutputDataT]] + name: str | None + description: str | None + + def __init__( + self, + type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + ): + self.outputs = _flatten_output_spec(type_) + self.name = name + self.description = description + + +@dataclass(init=False) +class PromptedStructuredOutput(Generic[OutputDataT]): + """Marker class to use a prompt to tell the model what to output and optionally customize the prompt.""" + + outputs: Sequence[OutputTypeOrFunction[OutputDataT]] + name: str | None + description: str | None + template: str | None + """Template for the prompt passed to the model. + The '{schema}' placeholder will be replaced with the output JSON schema. + If not specified, the default template specified on the model's profile will be used. + """ + + def __init__( + self, + type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + template: str | None = None, + ): + self.outputs = _flatten_output_spec(type_) + self.name = name + self.description = description + self.template = template + + +@dataclass +class TextOutput(Generic[OutputDataT]): + """Marker class to use text output for an output function taking a string argument.""" + + output_function: TextOutputFunction[OutputDataT] + + +def _get_union_args(tp: Any) -> tuple[Any, ...]: + """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple.""" + if typing_objects.is_typealiastype(tp): + tp = tp.__value__ + + origin = get_origin(tp) + if is_union_origin(origin): + return get_args(tp) + else: + return () + + +def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: + outputs: Sequence[T] + if isinstance(output_spec, Sequence): + outputs = output_spec + else: + outputs = (output_spec,) + + outputs_flat: list[T] = [] + for output in outputs: + if union_types := _get_union_args(output): + outputs_flat.extend(union_types) + else: + outputs_flat.append(output) + return outputs_flat + + +OutputSpec = TypeAliasType( + 'OutputSpec', + Union[ + OutputTypeOrFunction[T_co], + ToolOutput[T_co], + ModelStructuredOutput[T_co], + PromptedStructuredOutput[T_co], + TextOutput[T_co], + Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], + ], + type_params=(T_co,), +) + +TextOutputFunction = TypeAliasType( + 'TextOutputFunction', + Union[ + Callable[[RunContext, str], Union[Awaitable[T_co], T_co]], + Callable[[str], Union[Awaitable[T_co], T_co]], + ], + type_params=(T_co,), +) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 8d87f973d..c114cfb10 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -6,8 +6,7 @@ from typing_extensions import Self -from pydantic_ai._output import StructuredOutputMode - +from ..output import StructuredOutputMode from ._json_schema import JsonSchemaTransformer @@ -17,13 +16,13 @@ class ModelProfile: supports_tools: bool = True """Whether the model supports tools.""" - supports_json_schema_response_format: bool = False - """Whether the model supports the JSON schema response format.""" - supports_json_object_response_format: bool = False - """Whether the model supports the JSON object response format.""" + supports_structured_output: bool = False + """Whether the model supports JSON schema output.""" + supports_json_output: bool = False + """Whether the model supports JSON object output.""" default_structured_output_mode: StructuredOutputMode = 'tool' """The default structured output mode to use for the model.""" - structured_output_instructions_template: str = dedent( + prompted_structured_output_template: str = dedent( """ Always respond with a JSON object that's compatible with this schema: @@ -32,7 +31,7 @@ class ModelProfile: Don't include any text or Markdown fencing before or after. """ ) - """The instructions to use for prompted JSON output. The '{schema}' placeholder will be replaced with the JSON schema for the output.""" + """The instructions template to use for prompted structured output. The '{schema}' placeholder will be replaced with the JSON schema for the output.""" json_schema_transformer: type[JsonSchemaTransformer] | None = None """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index de85ea8d5..652249aca 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -12,8 +12,8 @@ def google_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Google model.""" return ModelProfile( json_schema_transformer=GoogleJsonSchemaTransformer, - supports_json_schema_response_format=True, - supports_json_object_response_format=True, + supports_structured_output=True, + supports_json_output=True, ) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index 78be9814a..ea4c474b8 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -26,12 +26,12 @@ def openai_model_profile(model_name: str) -> ModelProfile: """Get the model profile for an OpenAI model.""" is_reasoning_model = model_name.startswith('o') # The JSON schema response format is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later. - # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `structured_text` is only used - # when the user specifically uses the StructuredTextOutput marker, so an error from the API is acceptable. + # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `model_structured` is only used + # when the user specifically uses the ModelStructuredOutput marker, so an error from the API is acceptable. return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, - supports_json_schema_response_format=True, - supports_json_object_response_format=True, + supports_structured_output=True, + supports_json_output=True, openai_supports_sampling_settings=not is_reasoning_model, ) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index e8131e247..fa6fcd1b2 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -12,19 +12,19 @@ from . import _utils, exceptions, messages as _messages, models from ._output import ( - OutputDataT, OutputDataT_inv, OutputSchema, OutputValidator, OutputValidatorFunc, PlainTextOutputSchema, - StructuredTextOutput, - TextOutput, TextOutputSchema, - ToolOutput, ToolOutputSchema, ) from .messages import AgentStreamEvent, FinalResultEvent +from .output import ( + OutputDataT, + ToolOutput, +) from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits @@ -32,8 +32,6 @@ 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', - 'TextOutput', - 'StructuredTextOutput', 'OutputValidatorFunc', ) diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output.yaml rename to tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output.yaml diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output_multiple.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_anthropic/test_anthropic_structured_text_output_multiple.yaml rename to tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_model_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_gemini/test_gemini_structured_text_output.yaml rename to tests/models/cassettes/test_gemini/test_gemini_model_structured_output.yaml diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_model_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_gemini/test_gemini_structured_text_output_multiple.yaml rename to tests/models/cassettes/test_gemini/test_gemini_model_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions.yaml rename to tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output.yaml diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_multiple.yaml rename to tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_with_tools.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_with_tools.yaml similarity index 100% rename from tests/models/cassettes/test_gemini/test_gemini_structured_text_output_with_instructions_with_tools.yaml rename to tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_with_tools.yaml diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output.yaml b/tests/models/cassettes/test_google/test_google_model_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_google/test_google_structured_text_output.yaml rename to tests/models/cassettes/test_google/test_google_model_structured_output.yaml diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_model_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_google/test_google_structured_text_output_multiple.yaml rename to tests/models/cassettes/test_google/test_google_model_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_google/test_google_prompted_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions.yaml rename to tests/models/cassettes/test_google/test_google_prompted_structured_output.yaml diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_google/test_google_prompted_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_multiple.yaml rename to tests/models/cassettes/test_google/test_google_prompted_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_with_tools.yaml b/tests/models/cassettes/test_google/test_google_prompted_structured_output_with_tools.yaml similarity index 100% rename from tests/models/cassettes/test_google/test_google_structured_text_output_with_instructions_with_tools.yaml rename to tests/models/cassettes/test_google/test_google_prompted_structured_output_with_tools.yaml diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output.yaml b/tests/models/cassettes/test_openai/test_openai_model_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_openai/test_openai_structured_text_output.yaml rename to tests/models/cassettes/test_openai/test_openai_model_structured_output.yaml diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_model_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_openai/test_openai_structured_text_output_multiple.yaml rename to tests/models/cassettes/test_openai/test_openai_model_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions.yaml rename to tests/models/cassettes/test_openai/test_openai_prompted_structured_output.yaml diff --git a/tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_openai/test_openai_structured_text_output_with_instructions_multiple.yaml rename to tests/models/cassettes/test_openai/test_openai_prompted_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output.yaml b/tests/models/cassettes/test_openai_responses/test_model_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_openai_responses/test_structured_text_output.yaml rename to tests/models/cassettes/test_openai_responses/test_model_structured_output.yaml diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_model_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_openai_responses/test_structured_text_output_multiple.yaml rename to tests/models/cassettes/test_openai_responses/test_model_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_structured_output.yaml similarity index 100% rename from tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions.yaml rename to tests/models/cassettes/test_openai_responses/test_prompted_structured_output.yaml diff --git a/tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_structured_output_multiple.yaml similarity index 100% rename from tests/models/cassettes/test_openai_responses/test_structured_text_output_with_instructions_multiple.yaml rename to tests/models/cassettes/test_openai_responses/test_prompted_structured_output_multiple.yaml diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 73207dfad..617b5b2cf 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -14,6 +14,7 @@ from pydantic import BaseModel from pydantic_ai import Agent, ModelHTTPError, ModelRetry +from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ( BinaryContent, DocumentUrl, @@ -27,7 +28,8 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.output import ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput +from pydantic_ai.result import Usage from pydantic_ai.settings import ModelSettings from ..conftest import IsDatetime, IsNow, IsStr, TestEnv, raise_if_exception, try_import @@ -1301,14 +1303,14 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_anthropic_structured_text_output(allow_model_requests: None, anthropic_api_key: str): +async def test_anthropic_prompted_structured_output(allow_model_requests: None, anthropic_api_key: str): m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -1396,7 +1398,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_anthropic_structured_text_output_multiple(allow_model_requests: None, anthropic_api_key: str): +async def test_anthropic_prompted_structured_output_multiple(allow_model_requests: None, anthropic_api_key: str): m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) class CityLocation(BaseModel): @@ -1407,7 +1409,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=PromptedStructuredOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1453,3 +1455,17 @@ class CountryLanguage(BaseModel): ), ] ) + + +@pytest.mark.vcr() +async def test_anthropic_model_structured_output(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ModelStructuredOutput(CityLocation)) + + with pytest.raises(UserError, match='Structured output is not supported by the model.'): + await agent.run('What is the largest city in the user country?') diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 38839b8f4..954dacb26 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -52,8 +52,9 @@ _GeminiTools, _GeminiUsageMetaData, ) +from pydantic_ai.output import ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition from ..conftest import ClientWithHandler, IsDatetime, IsNow, IsStr, TestEnv @@ -1663,25 +1664,25 @@ def upcase(text: str) -> str: @pytest.mark.vcr() -async def test_gemini_structured_text_output_with_tools(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_model_structured_output_with_tools(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=ModelStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: return 'Mexico' # pragma: no cover - with pytest.raises(UserError, match='Gemini does not support JSON schema output and tools at the same time.'): + with pytest.raises(UserError, match='Gemini does not support structured output and tools at the same time.'): await agent.run('What is the largest city in the user country?') @pytest.mark.vcr() -async def test_gemini_structured_text_output(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_model_structured_output(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): @@ -1690,7 +1691,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=ModelStructuredOutput(CityLocation)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1733,7 +1734,7 @@ class CityLocation(BaseModel): @pytest.mark.vcr() -async def test_gemini_structured_text_output_multiple(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_model_structured_output_multiple(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): @@ -1744,7 +1745,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=ModelStructuredOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the primarily language spoken in Mexico?') assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) @@ -1792,14 +1793,14 @@ class CountryLanguage(BaseModel): @pytest.mark.vcr() -async def test_gemini_structured_text_output_with_instructions(allow_model_requests: None, gemini_api_key: str): +async def test_gemini_prompted_structured_output(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1844,16 +1845,14 @@ class CityLocation(BaseModel): @pytest.mark.vcr() -async def test_gemini_structured_text_output_with_instructions_with_tools( - allow_model_requests: None, gemini_api_key: str -): +async def test_gemini_prompted_structured_output_with_tools(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -1931,9 +1930,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_gemini_structured_text_output_with_instructions_multiple( - allow_model_requests: None, gemini_api_key: str -): +async def test_gemini_prompted_structured_output_multiple(allow_model_requests: None, gemini_api_key: str): m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) class CityLocation(BaseModel): @@ -1944,7 +1941,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 8dff0c8f6..d74edbd58 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -35,7 +35,8 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.output import ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput +from pydantic_ai.result import Usage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -961,24 +962,24 @@ async def get_user_country() -> str: ) -async def test_google_structured_text_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_model_structured_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=ModelStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: return 'Mexico' # pragma: no cover - with pytest.raises(UserError, match='Gemini does not support JSON schema output and tools at the same time.'): + with pytest.raises(UserError, match='Gemini does not support structured output and tools at the same time.'): await agent.run('What is the largest city in the user country?') -async def test_google_structured_text_output(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_model_structured_output(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): @@ -987,7 +988,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=ModelStructuredOutput(CityLocation)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1028,7 +1029,7 @@ class CityLocation(BaseModel): ) -async def test_google_structured_text_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): +async def test_google_model_structured_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): @@ -1039,7 +1040,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=ModelStructuredOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the primarily language spoken in Mexico?') assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) @@ -1085,16 +1086,14 @@ class CountryLanguage(BaseModel): ) -async def test_google_structured_text_output_with_instructions( - allow_model_requests: None, google_provider: GoogleProvider -): +async def test_google_prompted_structured_output(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) @@ -1137,7 +1136,7 @@ class CityLocation(BaseModel): ) -async def test_google_structured_text_output_with_instructions_with_tools( +async def test_google_prompted_structured_output_with_tools( allow_model_requests: None, google_provider: GoogleProvider ): m = GoogleModel('gemini-2.5-pro-preview-05-06', provider=google_provider) @@ -1146,7 +1145,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -1229,9 +1228,7 @@ async def get_user_country() -> str: ) -async def test_google_structured_text_output_with_instructions_multiple( - allow_model_requests: None, google_provider: GoogleProvider -): +async def test_google_prompted_structured_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.0-flash', provider=google_provider) class CityLocation(BaseModel): @@ -1242,7 +1239,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput([CityLocation, CountryLanguage])) result = await agent.run('What is the largest city in Mexico?') assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index d14ddb470..890c44d08 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -30,11 +30,12 @@ UserPromptPart, ) from pydantic_ai.models.gemini import GeminiModel +from pydantic_ai.output import ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput from pydantic_ai.profiles import ModelProfile from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput, Usage +from pydantic_ai.result import Usage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -538,7 +539,7 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): assert result.is_complete -async def test_stream_structured_json_schema_output(allow_model_requests: None): +async def test_stream_model_structured_output(allow_model_requests: None): stream = [ chunk([]), text_chunk('{"first": "One'), @@ -548,7 +549,7 @@ async def test_stream_structured_json_schema_output(allow_model_requests: None): ] mock_client = MockOpenAI.create_mock_stream(stream) m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) - agent = Agent(m, output_type=StructuredTextOutput(MyTypedDict)) + agent = Agent(m, output_type=ModelStructuredOutput(MyTypedDict)) async with agent.run_stream('') as result: assert not result.is_complete @@ -1962,7 +1963,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_structured_text_output(allow_model_requests: None, openai_api_key: str): +async def test_openai_model_structured_output(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -1971,7 +1972,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=ModelStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -2045,7 +2046,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_structured_text_output_multiple(allow_model_requests: None, openai_api_key: str): +async def test_openai_model_structured_output_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -2056,7 +2057,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=ModelStructuredOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: @@ -2134,14 +2135,14 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_structured_text_output_with_instructions(allow_model_requests: None, openai_api_key: str): +async def test_openai_prompted_structured_output(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -2229,9 +2230,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_openai_structured_text_output_with_instructions_multiple( - allow_model_requests: None, openai_api_key: str -): +async def test_openai_prompted_structured_output_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -2242,7 +2241,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index 68e405358..6938adc50 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -20,8 +20,8 @@ ToolReturnPart, UserPromptPart, ) +from pydantic_ai.output import ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput from pydantic_ai.profiles.openai import openai_model_profile -from pydantic_ai.result import StructuredTextOutput, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import Usage @@ -678,7 +678,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_structured_text_output(allow_model_requests: None, openai_api_key: str): +async def test_model_structured_output(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -687,7 +687,7 @@ class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=ModelStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -748,7 +748,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_structured_text_output_multiple(allow_model_requests: None, openai_api_key: str): +async def test_model_structured_output_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -759,7 +759,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage])) + agent = Agent(m, output_type=ModelStructuredOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: @@ -824,14 +824,14 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_structured_text_output_with_instructions(allow_model_requests: None, openai_api_key: str): +async def test_prompted_structured_output(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): city: str country: str - agent = Agent(m, output_type=StructuredTextOutput(CityLocation, instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) @agent.tool_plain async def get_user_country() -> str: @@ -906,7 +906,7 @@ async def get_user_country() -> str: @pytest.mark.vcr() -async def test_structured_text_output_with_instructions_multiple(allow_model_requests: None, openai_api_key: str): +async def test_prompted_structured_output_multiple(allow_model_requests: None, openai_api_key: str): m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) class CityLocation(BaseModel): @@ -917,7 +917,7 @@ class CountryLanguage(BaseModel): country: str language: str - agent = Agent(m, output_type=StructuredTextOutput([CityLocation, CountryLanguage], instructions=True)) + agent = Agent(m, output_type=PromptedStructuredOutput([CityLocation, CountryLanguage])) @agent.tool_plain async def get_user_country() -> str: diff --git a/tests/test_agent.py b/tests/test_agent.py index bcccb5517..de5ea7d3e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -14,11 +14,11 @@ from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai._output import ( + ModelStructuredOutput, OutputSpec, - StructuredTextOutput, + PromptedStructuredOutput, TextOutput, TextOutputSchema, - ToolOutput, ToolOutputSchema, ) from pydantic_ai.agent import AgentRunResult @@ -38,6 +38,7 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.output import ToolOutput from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition @@ -1269,9 +1270,7 @@ def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: tool_model = FunctionModel(hello, profile=ModelProfile(default_structured_output_mode='tool')) structured_text_model = FunctionModel( hello, - profile=ModelProfile( - supports_json_schema_response_format=True, default_structured_output_mode='structured_text' - ), + profile=ModelProfile(supports_structured_output=True, default_structured_output_mode='prompted_structured'), ) class Foo(BaseModel): @@ -1280,11 +1279,11 @@ class Foo(BaseModel): tool_agent = Agent(tool_model, output_type=Foo) assert tool_agent._output_schema.mode == 'tool' # type: ignore[reportPrivateUsage] - structured_text_agent = Agent(structured_text_model, output_type=Foo) - assert structured_text_agent._output_schema.mode == 'structured_text' # type: ignore[reportPrivateUsage] + structured_agent = Agent(structured_text_model, output_type=Foo) + assert structured_agent._output_schema.mode == 'prompted_structured' # type: ignore[reportPrivateUsage] -def test_output_type_structured_text(): +def test_prompted_structured_output(): def return_city_location(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: text = CityLocation(city='Mexico City', country='Mexico').model_dump_json() return ModelResponse(parts=[TextPart(content=text)]) @@ -1299,8 +1298,8 @@ class CityLocation(BaseModel): agent = Agent( m, - output_type=StructuredTextOutput( - CityLocation, name='City & Country', description='Description from StructuredTextOutput' + output_type=PromptedStructuredOutput( + CityLocation, name='City & Country', description='Description from PromptedStructuredOutput' ), ) @@ -1318,7 +1317,7 @@ class CityLocation(BaseModel): instructions="""\ Always respond with a JSON object that's compatible with this schema: -{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from StructuredTextOutput. Description from docstring."} +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from PromptedStructuredOutput. Description from docstring."} Don't include any text or Markdown fencing before or after.\ """, @@ -1333,7 +1332,7 @@ class CityLocation(BaseModel): ) -def test_output_type_structured_text_with_custom_instructions(): +def test_prompted_structured_output_with_template(): def return_foo(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: text = Foo(bar='baz').model_dump_json() return ModelResponse(parts=[TextPart(content=text)]) @@ -1343,7 +1342,7 @@ def return_foo(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: class Foo(BaseModel): bar: str - agent = Agent(m, output_type=StructuredTextOutput(Foo, instructions='Gimme some JSON:')) + agent = Agent(m, output_type=PromptedStructuredOutput(Foo, template='Gimme some JSON:')) result = agent.run_sync('What is the capital of Mexico?') assert result.output == snapshot(Foo(bar='baz')) @@ -1372,7 +1371,7 @@ class Foo(BaseModel): ) -def test_output_type_structured_text_with_defs(): +def test_prompted_structured_output_with_defs(): class Foo(BaseModel): """Foo description""" @@ -1408,7 +1407,7 @@ def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent( m, - output_type=StructuredTextOutput( + output_type=PromptedStructuredOutput( [FooBar, FooBaz], name='FooBar or FooBaz', description='FooBar or FooBaz description' ), ) @@ -1446,7 +1445,7 @@ def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: ) -def test_output_type_json_schema(): +def test_model_structured_output(): def return_city_location(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: if len(messages) == 1: text = '{"city": "Mexico City"}' @@ -1454,7 +1453,7 @@ def return_city_location(messages: list[ModelMessage], _info: AgentInfo) -> Mode text = '{"city": "Mexico City", "country": "Mexico"}' return ModelResponse(parts=[TextPart(content=text)]) - m = FunctionModel(return_city_location, profile=ModelProfile(supports_json_schema_response_format=True)) + m = FunctionModel(return_city_location, profile=ModelProfile(supports_structured_output=True)) class CityLocation(BaseModel): city: str @@ -1462,7 +1461,7 @@ class CityLocation(BaseModel): agent = Agent( m, - output_type=StructuredTextOutput(CityLocation), + output_type=ModelStructuredOutput(CityLocation), ) result = agent.run_sync('What is the capital of Mexico?') @@ -1509,7 +1508,7 @@ class CityLocation(BaseModel): ) -def test_output_type_structured_text_function_with_retry(): +def test_prompted_structured_output_function_with_retry(): class Weather(BaseModel): temperature: float description: str @@ -1529,7 +1528,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content=args_json)]) - agent = Agent(FunctionModel(call_tool), output_type=StructuredTextOutput(get_weather, instructions=True)) + agent = Agent(FunctionModel(call_tool), output_type=PromptedStructuredOutput(get_weather)) result = agent.run_sync('New York City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert result.all_messages() == snapshot( @@ -3112,11 +3111,11 @@ class Foo(BaseModel): def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart('hello')]) # pragma: no cover - model = FunctionModel(hello, profile=ModelProfile(supports_tools=False, supports_json_schema_response_format=False)) + model = FunctionModel(hello, profile=ModelProfile(supports_tools=False, supports_structured_output=False)) - agent = Agent(model, output_type=StructuredTextOutput(Foo, instructions=False)) + agent = Agent(model, output_type=ModelStructuredOutput(Foo)) - with pytest.raises(UserError, match='Structured output without using instructions is not supported by the model.'): + with pytest.raises(UserError, match='Structured output is not supported by the model.'): agent.run_sync('Hello') agent = Agent(model, output_type=ToolOutput(Foo)) diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c652f3abe..56ce279e7 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -13,7 +13,6 @@ from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages -from pydantic_ai._output import StructuredTextOutput, TextOutput from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( FunctionToolCallEvent, @@ -29,6 +28,7 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.output import PromptedStructuredOutput, TextOutput from pydantic_ai.result import AgentStream, FinalResult, Usage from pydantic_graph import End @@ -1059,14 +1059,14 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf ) -async def test_stream_output_type_structured_text(): +async def test_stream_prompted_structured_output(): class CityLocation(BaseModel): city: str country: str | None = None m = TestModel(custom_output_text='{"city": "Mexico City", "country": "Mexico"}') - agent = Agent(m, output_type=StructuredTextOutput(CityLocation)) + agent = Agent(m, output_type=PromptedStructuredOutput(CityLocation)) async with agent.run_stream('') as result: assert not result.is_complete diff --git a/tests/test_tools.py b/tests/test_tools.py index 8be113610..e3eeedde0 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -11,10 +11,11 @@ from pydantic_core import PydanticSerializationError, core_schema from typing_extensions import TypedDict -from pydantic_ai import Agent, RunContext, Tool, ToolOutput, UserError +from pydantic_ai import Agent, RunContext, Tool, UserError from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.output import ToolOutput from pydantic_ai.tools import ToolDefinition diff --git a/tests/typed_agent.py b/tests/typed_agent.py index eaa2c4fa8..941bbf987 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -9,8 +9,8 @@ from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool -from pydantic_ai._output import TextOutput, ToolOutput from pydantic_ai.agent import AgentRunResult +from pydantic_ai.output import TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition # Define here so we can check `if MYPY` below. This will not be executed, MYPY will always set it to True From 71d165599d1b27779574109f371e81e87e121bbe Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 22:20:12 +0000 Subject: [PATCH 37/90] Fix WrapperModel.profile --- pydantic_ai_slim/pydantic_ai/models/wrapper.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py index 1d37e320d..07d319ec4 100644 --- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py @@ -3,9 +3,11 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass +from functools import cached_property from typing import Any from ..messages import ModelMessage, ModelResponse +from ..profiles import ModelProfile from ..settings import ModelSettings from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model @@ -47,5 +49,9 @@ def model_name(self) -> str: def system(self) -> str: return self.wrapped.system + @cached_property + def profile(self) -> ModelProfile: + return self.wrapped.profile + def __getattr__(self, item: str): return getattr(self.wrapped, item) # pragma: no cover From 8c041441f297768cc8e452ec32d018bff5f06123 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 22:25:01 +0000 Subject: [PATCH 38/90] Update output modes docs --- docs/api/output.md | 11 ++ docs/api/result.md | 5 - docs/output.md | 116 +++++++++++++----- mkdocs.yml | 1 + pydantic_ai_slim/pydantic_ai/agent.py | 2 +- .../pydantic_ai/profiles/openai.py | 4 +- tests/test_examples.py | 7 +- 7 files changed, 106 insertions(+), 40 deletions(-) create mode 100644 docs/api/output.md diff --git a/docs/api/output.md b/docs/api/output.md new file mode 100644 index 000000000..b3dffe6bf --- /dev/null +++ b/docs/api/output.md @@ -0,0 +1,11 @@ +# `pydantic_ai.output` + +::: pydantic_ai.output + options: + inherited_members: true + members: + - OutputDataT + - ToolOutput + - ModelStructuredOutput + - PromptedStructuredOutput + - TextOutput diff --git a/docs/api/result.md b/docs/api/result.md index 85593a221..d07778e95 100644 --- a/docs/api/result.md +++ b/docs/api/result.md @@ -4,9 +4,4 @@ options: inherited_members: true members: - - OutputDataT - StreamedRunResult - - ToolOutput - - ModelStructuredOutput - - PromptedStructuredOutput - - TextOutput diff --git a/docs/output.md b/docs/output.md index 0151064bf..8119c7a81 100644 --- a/docs/output.md +++ b/docs/output.md @@ -133,8 +133,8 @@ from typing import Union from pydantic import BaseModel from pydantic_ai import Agent, ModelRetry, RunContext -from pydantic_ai.output import ToolRetryError from pydantic_ai.exceptions import UnexpectedModelBehavior +from pydantic_ai.output import ToolRetryError class Row(BaseModel): @@ -224,38 +224,55 @@ print(result.output) """ result = router_agent.run_sync('Select all pets') -print(result.output) +print(repr(result.output)) """ -explanation = "The requested table 'pets' does not exist in the database. The only available table is 'capital_cities', which does not contain data about pets." +RouterFailure(explanation="The requested table 'pets' does not exist in the database. The only available table is 'capital_cities', which does not contain data about pets.") """ result = router_agent.run_sync('How do I fly from Amsterdam to Mexico City?') -print(result.output) +print(repr(result.output)) """ -explanation = 'I am not equipped to provide travel information, such as flights from Amsterdam to Mexico City.' +RouterFailure(explanation='I am not equipped to provide travel information, such as flights from Amsterdam to Mexico City.') """ ``` +Note that if you provide an output function that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like to the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. If desired, this marker class can be used alongside one or more [`ToolOutput`](#tool-output) marker classes (or unmarked types or functions) in a list provided to `output_type`. + +```python +from pydantic_ai import Agent, TextOutput + + +def split_into_words(text: str) -> list[str]: + return text.split() + + +agent = Agent( + 'openai:gpt-4o', + output_type=TextOutput(split_into_words), +) +result = agent.run_sync('Who was Albert Einstein?') +print(result.output) +#> ['Albert', 'Einstein', 'was', 'a', 'German-born', 'theoretical', 'physicist.'] +``` + ### Output modes Pydantic AI implements three different methods to get a model to output structured data: -1. **Tool output**, where the output JSON schema is provided to the model as the parameters schema of a special output tool. This is the default as it's supported by virtually all models and has been shown to work very well. -2. **Structured text output** using a model API's native **"forced JSON Schema"** feature (aka "JSON Schema response format" or "Structured Outputs"), where the model is forced to only output text matching the provided JSON schema. This is currently only supported by OpenAI and Gemini and sometimes comes with restrictions (for example, Gemini cannot use tools at the same time as JSON output). -3. **Structured text output** using [**instructions**](agents.md#instructions), where the model is prompted to output text matching the provided JSON schema and it's up to the model to interpret those instructions correctly. If the model API supports the "JSON Mode" (aka "JSON Object response format") feature to force the model to output valid JSON, this is enabled, but it's still up to the model to abide by the schema. This is usable with all models, but is the least reliable approach as the model is not forced to match the schema. Pydantic AI will validate the returned structured data and tell the model to try again if validation fails, but if the model is not intelligent enough this may not be sufficient. +1. [Tool Output](#tool-output) +2. [Model Structured Output](#model-structured-output) +3. [Prompted Structured Output](#prompted-structured-output) -By default, Pydantic AI will use the tool output mode and register a separate output tool for each output type (or function). If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.output.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. +#### Tool Output -If you'd like to use the structured text output mode, you can wrap the type(s) in the [`StructuredTextOutput`][pydantic_ai.output.StructuredTextOutput] marker class that also lets you specify a name and description if the name and docstring of the type or function are not sufficient. Additionally, it supports an `instructions` argument that is `None` by default, indicating that Pydantic AI should choose the best strategy supported by the model: forced JSON schema or instructions. You can set it to `False` to never use instructions (which will result in an error if the forced JSON schema feature is not supported by the model) or `True` to always use instructions. If you'd like to use a custom instructions template instead of the [default][pydantic_ai.profiles.ModelProfile.prompted_structured_output_template], you can pass a string with a `{schema}` placeholder that will be replaced with the actual JSON schema. +In the default Tool Output mode, the output JSON schema of each output type (or function) is provided to the model as the parameters schema of a special output tool. This is the default as it's supported by virtually all models and has been shown to work very well. -Finally, if you provide an [output function](#output-functions) that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like to the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. If desired, this marker class can be used alongside one or more `ToolOutput` marker classes (or unmarked types or functions) in a list provided to `output_type`. - -Here's an example of these 3 output mode marker classes in action: +If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.output.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. ```python from pydantic import BaseModel -from pydantic_ai import Agent, ModelStructuredOutput, PromptedStructuredOutput, TextOutput, ToolOutput +from pydantic_ai import Agent, ToolOutput class Fruit(BaseModel): @@ -268,51 +285,88 @@ class Vehicle(BaseModel): wheels: int -class Device(BaseModel): - name: str - kind: str - - agent = Agent( 'openai:gpt-4o', output_type=[ ToolOutput(Fruit, name='return_fruit'), ToolOutput(Vehicle, name='return_vehicle'), - ToolOutput(Device, name='return_device'), ], ) result = agent.run_sync('What is a banana?') print(repr(result.output)) #> Fruit(name='banana', color='yellow') +``` + +#### Model Structured Output + +Model Structured Output mode uses a model's native "Structured Outputs" feature (aka "JSON Schema response format"), where the model is forced to only output text matching the provided JSON schema. This is currently only supported by OpenAI and Gemini and sometimes comes with restrictions. For example, Gemini cannot use tools at the same time as structured output and attempting to do so will result in an error. + +To use this mode, you can wrap the output type(s) in the [`ModelStructuredOutput`][pydantic_ai.output.ModelStructuredOutput] marker class that also lets you specify a `name` and `description` if the name and docstring of the type or function are not sufficient. + +```python +from pydantic import BaseModel + +from pydantic_ai import Agent, ModelStructuredOutput + + +class Fruit(BaseModel): + name: str + color: str + + +class Vehicle(BaseModel): + name: str + wheels: int + agent = Agent( 'openai:gpt-4o', - output_type=ModelStructuredOutput([Fruit, Vehicle, Device]), + output_type=ModelStructuredOutput([Fruit, Vehicle]), ) result = agent.run_sync('What is a Ford Explorer?') print(repr(result.output)) #> Vehicle(name='Ford Explorer', wheels=4) +``` + +#### Prompted Structured Output + +In this mode, the model is prompted to output text matching the provided JSON schema through its [instructions](agents.md#instructions) and it's up to the model to interpret those instructions correctly. This is usable with all models, but is the least reliable approach as the model is not forced to match the schema. + +If the model API supports the "JSON Mode" feature (aka "JSON Object response format") to force the model to output valid JSON, this is enabled, but it's still up to the model to abide by the schema. Pydantic AI will validate the returned structured data and tell the model to try again if validation fails, but if the model is not intelligent enough this may not be sufficient. + +To use this mode, you can wrap the output type(s) in the [`PromptedStructuredOutput`][pydantic_ai.output.PromptedStructuredOutput] marker class that also lets you specify a `name` and `description` if the name and docstring of the type or function are not sufficient. Additionally, it supports an `template` argument lets you specify a custom instructions template to be used instead of the [default][pydantic_ai.profiles.ModelProfile.prompted_structured_output_template]. + +```python +from pydantic import BaseModel + +from pydantic_ai import Agent, PromptedStructuredOutput + + +class Vehicle(BaseModel): + name: str + wheels: int + + +class Device(BaseModel): + name: str + kind: str + agent = Agent( 'openai:gpt-4o', - output_type=PromptedStructuredOutput([Fruit, Vehicle, Device]), + output_type=PromptedStructuredOutput([Vehicle, Device]), ) result = agent.run_sync('What is a MacBook?') print(repr(result.output)) #> Device(name='MacBook', kind='laptop') - -def split_into_words(text: str) -> list[str]: - return text.split() - - agent = Agent( 'openai:gpt-4o', - output_type=TextOutput(split_into_words), + output_type=PromptedStructuredOutput([Vehicle, Device], template='Gimme some JSON: {schema}'), ) -result = agent.run_sync('Who was Albert Einstein?') -print(result.output) -#> ['Albert', 'Einstein', 'was', 'a', 'German-born', 'theoretical', 'physicist.'] +result = agent.run_sync('What is a Ford Explorer?') +print(repr(result.output)) +#> Vehicle(name='Ford Explorer', wheels=4) ``` ### Output validators {#output-validator-functions} diff --git a/mkdocs.yml b/mkdocs.yml index 70ea19487..e5063ac83 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -63,6 +63,7 @@ nav: - api/agent.md - api/tools.md - api/common_tools.md + - api/output.md - api/result.md - api/messages.md - api/exceptions.md diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 4d58f83f4..5debf3a8f 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -93,7 +93,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] - and the result data type they return, [`OutputDataT`][pydantic_ai.result.OutputDataT]. + and the result data type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index ea4c474b8..4f43e6211 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -25,9 +25,9 @@ class OpenAIModelProfile(ModelProfile): def openai_model_profile(model_name: str) -> ModelProfile: """Get the model profile for an OpenAI model.""" is_reasoning_model = model_name.startswith('o') - # The JSON schema response format is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later. + # Structured Outputs (output mode 'model_structured') is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later. # We leave it in here for all models because the `default_structured_output_mode` is `'tool'`, so `model_structured` is only used - # when the user specifically uses the ModelStructuredOutput marker, so an error from the API is acceptable. + # when the user specifically uses the `ModelStructuredOutput` marker, so an error from the API is acceptable. return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, supports_structured_output=True, diff --git a/tests/test_examples.py b/tests/test_examples.py index 745dfd371..d0a2706ad 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -764,7 +764,12 @@ def raise_http_error(messages: list[ModelMessage], info: AgentInfo) -> ModelResp return model else: model_name = model if isinstance(model, str) else model.model_name - return FunctionModel(model_logic, stream_function=stream_model_logic, model_name=model_name) + return FunctionModel( + model_logic, + stream_function=stream_model_logic, + model_name=model_name, + profile=model.profile if isinstance(model, Model) else None, + ) def mock_group_by_temporal(aiter: Any, soft_max_interval: float | None) -> Any: From d78b5f7319b82e65ea1f88ec8c73e69094b4a75d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 22:34:42 +0000 Subject: [PATCH 39/90] Add examples to output mode marker docstrings --- docs/output.md | 21 +++- pydantic_ai_slim/pydantic_ai/output.py | 130 ++++++++++++++++++++++++- 2 files changed, 143 insertions(+), 8 deletions(-) diff --git a/docs/output.md b/docs/output.md index 8119c7a81..499832b3b 100644 --- a/docs/output.md +++ b/docs/output.md @@ -236,6 +236,8 @@ RouterFailure(explanation='I am not equipped to provide travel information, such """ ``` +#### Text output + Note that if you provide an output function that takes a string, Pydantic AI will by default create an output tool like for any other output function. If instead you'd like to the model to provide the string using plain text output, you can wrap the function in the [`TextOutput`][pydantic_ai.output.TextOutput] marker class. If desired, this marker class can be used alongside one or more [`ToolOutput`](#tool-output) marker classes (or unmarked types or functions) in a list provided to `output_type`. ```python @@ -321,7 +323,11 @@ class Vehicle(BaseModel): agent = Agent( 'openai:gpt-4o', - output_type=ModelStructuredOutput([Fruit, Vehicle]), + output_type=ModelStructuredOutput( + [Fruit, Vehicle], + name='Fruit or vehicle', + description='Return a fruit or vehicle.' + ), ) result = agent.run_sync('What is a Ford Explorer?') print(repr(result.output)) @@ -354,7 +360,11 @@ class Device(BaseModel): agent = Agent( 'openai:gpt-4o', - output_type=PromptedStructuredOutput([Vehicle, Device]), + output_type=PromptedStructuredOutput( + [Vehicle, Device], + name='Vehicle or device', + description='Return a vehicle or device.' + ), ) result = agent.run_sync('What is a MacBook?') print(repr(result.output)) @@ -362,7 +372,10 @@ print(repr(result.output)) agent = Agent( 'openai:gpt-4o', - output_type=PromptedStructuredOutput([Vehicle, Device], template='Gimme some JSON: {schema}'), + output_type=PromptedStructuredOutput( + [Vehicle, Device], + template='Gimme some JSON: {schema}' + ), ) result = agent.run_sync('What is a Ford Explorer?') print(repr(result.output)) @@ -374,7 +387,7 @@ print(repr(result.output)) Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. If you want to implement separate validation logic for different output types, it's recommended to use [output functions](#output-functions) instead, to save you from having to do `isinstance` checks inside the output validator. -If you want the model to output plain text, do your own processing or validation, and then have the agent's final output be the result of your function, it's recommended to use an [output function](#output-functions) with the `TextOutput` [output mode](#output-modes) marker class. +If you want the model to output plain text, do your own processing or validation, and then have the agent's final output be the result of your function, it's recommended to use an [output function](#output-functions) with the [`TextOutput` marker class](#text-output). Here's a simplified variant of the [SQL Generation example](examples/sql-gen.md): diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 73ad0d1d0..f9031afde 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -38,7 +38,37 @@ def __init__(self, tool_retry: RetryPromptPart): @dataclass(init=False) class ToolOutput(Generic[OutputDataT]): - """Marker class to use a tool for output and optionally customize the tool.""" + """Marker class to use a tool for output and optionally customize the tool. + + Example: + ```python + from pydantic import BaseModel + + from pydantic_ai import Agent, ToolOutput + + + class Fruit(BaseModel): + name: str + color: str + + + class Vehicle(BaseModel): + name: str + wheels: int + + + agent = Agent( + 'openai:gpt-4o', + output_type=[ + ToolOutput(Fruit, name='return_fruit'), + ToolOutput(Vehicle, name='return_vehicle'), + ], + ) + result = agent.run_sync('What is a banana?') + print(repr(result.output)) + #> Fruit(name='banana', color='yellow') + ``` + """ output: OutputTypeOrFunction[OutputDataT] name: str | None @@ -64,7 +94,38 @@ def __init__( @dataclass(init=False) class ModelStructuredOutput(Generic[OutputDataT]): - """Marker class to use the model's built-in structured outputs functionality for outputs and optionally customize the name and description.""" + """Marker class to use the model's built-in structured outputs functionality for outputs and optionally customize the name and description. + + Example: + ```python + from pydantic import BaseModel + + from pydantic_ai import Agent, ModelStructuredOutput + + + class Fruit(BaseModel): + name: str + color: str + + + class Vehicle(BaseModel): + name: str + wheels: int + + + agent = Agent( + 'openai:gpt-4o', + output_type=ModelStructuredOutput( + [Fruit, Vehicle], + name='Fruit or vehicle', + description='Return a fruit or vehicle.' + ), + ) + result = agent.run_sync('What is a Ford Explorer?') + print(repr(result.output)) + #> Vehicle(name='Ford Explorer', wheels=4) + ``` + """ outputs: Sequence[OutputTypeOrFunction[OutputDataT]] name: str | None @@ -84,7 +145,49 @@ def __init__( @dataclass(init=False) class PromptedStructuredOutput(Generic[OutputDataT]): - """Marker class to use a prompt to tell the model what to output and optionally customize the prompt.""" + """Marker class to use a prompt to tell the model what to output and optionally customize the prompt. + + Example: + ```python + from pydantic import BaseModel + + from pydantic_ai import Agent, PromptedStructuredOutput + + + class Vehicle(BaseModel): + name: str + wheels: int + + + class Device(BaseModel): + name: str + kind: str + + + agent = Agent( + 'openai:gpt-4o', + output_type=PromptedStructuredOutput( + [Vehicle, Device], + name='Vehicle or device', + description='Return a vehicle or device.' + ), + ) + result = agent.run_sync('What is a MacBook?') + print(repr(result.output)) + #> Device(name='MacBook', kind='laptop') + + agent = Agent( + 'openai:gpt-4o', + output_type=PromptedStructuredOutput( + [Vehicle, Device], + template='Gimme some JSON: {schema}' + ), + ) + result = agent.run_sync('What is a Ford Explorer?') + print(repr(result.output)) + #> Vehicle(name='Ford Explorer', wheels=4) + ``` + """ outputs: Sequence[OutputTypeOrFunction[OutputDataT]] name: str | None @@ -111,7 +214,26 @@ def __init__( @dataclass class TextOutput(Generic[OutputDataT]): - """Marker class to use text output for an output function taking a string argument.""" + """Marker class to use text output for an output function taking a string argument. + + Example: + ```python + from pydantic_ai import Agent, TextOutput + + + def split_into_words(text: str) -> list[str]: + return text.split() + + + agent = Agent( + 'openai:gpt-4o', + output_type=TextOutput(split_into_words), + ) + result = agent.run_sync('Who was Albert Einstein?') + print(result.output) + #> ['Albert', 'Einstein', 'was', 'a', 'German-born', 'theoretical', 'physicist.'] + ``` + """ output_function: TextOutputFunction[OutputDataT] From 70d1197e8128d17a591b12fb954bb41f79ee36db Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 22:50:22 +0000 Subject: [PATCH 40/90] Fix mypy type inference --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 4 +-- pydantic_ai_slim/pydantic_ai/agent.py | 26 ++++++++++---------- pydantic_ai_slim/pydantic_ai/output.py | 12 ++++----- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index b588f35d5..4c87fad38 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -17,7 +17,7 @@ from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage -from .output import OutputDataT +from .output import OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc @@ -877,7 +877,7 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], - output_type: _output.OutputSpec[OutputT], + output_type: OutputSpec[OutputT], ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 5debf3a8f..77d7711f3 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -31,7 +31,7 @@ ) from ._agent_graph import HistoryProcessor from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model -from .output import OutputDataT +from .output import OutputDataT, OutputSpec from .result import FinalResult, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -93,7 +93,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. Agents are generic in the dependency type they take [`AgentDepsT`][pydantic_ai.tools.AgentDepsT] - and the result data type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. + and the output type they return, [`OutputDataT`][pydantic_ai.output.OutputDataT]. By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. @@ -130,7 +130,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ - output_type: _output.OutputSpec[OutputDataT] + output_type: OutputSpec[OutputDataT] """ The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. """ @@ -165,7 +165,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - output_type: _output.OutputSpec[OutputDataT] = str, + output_type: OutputSpec[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] @@ -390,7 +390,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -420,7 +420,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -508,7 +508,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: _output.OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -540,7 +540,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -793,7 +793,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -823,7 +823,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -906,7 +906,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent], *, - output_type: _output.OutputSpec[RunOutputDataT], + output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -937,7 +937,7 @@ async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputSpec[RunOutputDataT] | None = None, + output_type: OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -1654,7 +1654,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( - self, output_type: _output.OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile + self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile ) -> _output.OutputSchema[RunOutputDataT]: if output_type is not None: if self._output_validators: diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index f9031afde..45d786dd0 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -16,12 +16,6 @@ T = TypeVar('T') -T_co = TypeVar('T_co', covariant=True) - -OutputTypeOrFunction = TypeAliasType( - 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) -) - OutputMode = Literal['text', 'tool', 'model_structured', 'prompted_structured', 'tool_or_text'] """All output modes.""" StructuredOutputMode = Literal['tool', 'model_structured', 'prompted_structured'] @@ -266,6 +260,12 @@ def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: return outputs_flat +T_co = TypeVar('T_co', covariant=True) + +OutputTypeOrFunction = TypeAliasType( + 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) +) + OutputSpec = TypeAliasType( 'OutputSpec', Union[ From 2eb7fd149ab638c2906f6f157e8a95ddbdcc08d5 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 22:52:10 +0000 Subject: [PATCH 41/90] Improve test coverage --- tests/test_agent.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index de5ea7d3e..7dc4e8b0a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1268,9 +1268,13 @@ def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover tool_model = FunctionModel(hello, profile=ModelProfile(default_structured_output_mode='tool')) - structured_text_model = FunctionModel( + model_structured_model = FunctionModel( hello, - profile=ModelProfile(supports_structured_output=True, default_structured_output_mode='prompted_structured'), + profile=ModelProfile(supports_structured_output=True, default_structured_output_mode='model_structured'), + ) + prompted_structured_model = FunctionModel( + hello, + profile=ModelProfile(default_structured_output_mode='prompted_structured'), ) class Foo(BaseModel): @@ -1279,8 +1283,11 @@ class Foo(BaseModel): tool_agent = Agent(tool_model, output_type=Foo) assert tool_agent._output_schema.mode == 'tool' # type: ignore[reportPrivateUsage] - structured_agent = Agent(structured_text_model, output_type=Foo) - assert structured_agent._output_schema.mode == 'prompted_structured' # type: ignore[reportPrivateUsage] + model_structured_agent = Agent(model_structured_model, output_type=Foo) + assert model_structured_agent._output_schema.mode == 'model_structured' # type: ignore[reportPrivateUsage] + + prompted_structured_agent = Agent(prompted_structured_model, output_type=Foo) + assert prompted_structured_agent._output_schema.mode == 'prompted_structured' # type: ignore[reportPrivateUsage] def test_prompted_structured_output(): From 9e00c32f5b726cdfce3513e90c688c5565a47fd5 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 23:19:01 +0000 Subject: [PATCH 42/90] Import cast and RunContext in _function_schema --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/_function_schema.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index fe48cf4f0..7b12f32fe 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast from opentelemetry.trace import Tracer from typing_extensions import TypeGuard, TypeVar, assert_never diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 31e6050f3..a4b47af93 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -21,9 +21,10 @@ from ._griffe import doc_descriptions from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor +from .tools import RunContext if TYPE_CHECKING: - from .tools import DocstringFormat, ObjectJsonSchema, RunContext + from .tools import DocstringFormat, ObjectJsonSchema __all__ = ('function_schema',) @@ -279,6 +280,4 @@ def _build_schema( def _is_call_ctx(annotation: Any) -> bool: """Return whether the annotation is the `RunContext` class, parameterized or not.""" - from .tools import RunContext - return annotation is RunContext or get_origin(annotation) is RunContext From 7de3c0d2ea036f9e6cc0a8ff544c362fbee1ecc2 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 23:26:39 +0000 Subject: [PATCH 43/90] Move RunContext and AgentDepsT into their own module to solve circular import --- pydantic_ai_slim/pydantic_ai/_cli.py | 5 +- .../pydantic_ai/_function_schema.py | 2 +- pydantic_ai_slim/pydantic_ai/_output.py | 3 +- .../pydantic_ai/_system_prompt.py | 3 +- pydantic_ai_slim/pydantic_ai/result.py | 2 +- pydantic_ai_slim/pydantic_ai/run_context.py | 56 +++++++++++++++++++ pydantic_ai_slim/pydantic_ai/tools.py | 49 +--------------- 7 files changed, 66 insertions(+), 54 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/run_context.py diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index feb007327..fa974349a 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -14,14 +14,13 @@ from typing_inspection.introspection import get_literal_values -from pydantic_ai.output import OutputDataT -from pydantic_ai.tools import AgentDepsT - from . import __version__ from .agent import Agent from .exceptions import UserError from .messages import ModelMessage from .models import KnownModelName, infer_model +from .output import OutputDataT +from .run_context import AgentDepsT try: import argcomplete diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index a4b47af93..9ed8bfdd4 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -21,7 +21,7 @@ from ._griffe import doc_descriptions from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor -from .tools import RunContext +from .run_context import RunContext if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index f0a7db7fe..27ed7455f 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -27,7 +27,8 @@ ToolRetryError, _flatten_output_spec, # pyright: ignore[reportPrivateUsage] ) -from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition +from .run_context import AgentDepsT, RunContext +from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition if TYPE_CHECKING: from .profiles import ModelProfile diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index df2b93e7e..7a3b1c9c3 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -6,7 +6,8 @@ from typing import Any, Callable, Generic, cast from . import _utils -from .tools import AgentDepsT, RunContext, SystemPromptFunc +from .run_context import AgentDepsT, RunContext +from .tools import SystemPromptFunc @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index fa6fcd1b2..3e0542b26 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -25,7 +25,7 @@ OutputDataT, ToolOutput, ) -from .tools import AgentDepsT, RunContext +from .run_context import AgentDepsT, RunContext from .usage import Usage, UsageLimits __all__ = ( diff --git a/pydantic_ai_slim/pydantic_ai/run_context.py b/pydantic_ai_slim/pydantic_ai/run_context.py new file mode 100644 index 000000000..bb7f47420 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/run_context.py @@ -0,0 +1,56 @@ +from __future__ import annotations as _annotations + +import dataclasses +from collections.abc import Sequence +from dataclasses import field +from typing import TYPE_CHECKING, Generic + +from typing_extensions import TypeVar + +from . import _utils, messages as _messages + +if TYPE_CHECKING: + from .models import Model + from .result import Usage + +AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) +"""Type variable for agent dependencies.""" + + +@dataclasses.dataclass(repr=False) +class RunContext(Generic[AgentDepsT]): + """Information about the current call.""" + + deps: AgentDepsT + """Dependencies for the agent.""" + model: Model + """The model used in this run.""" + usage: Usage + """LLM usage associated with the run.""" + prompt: str | Sequence[_messages.UserContent] | None + """The original user prompt passed to the run.""" + messages: list[_messages.ModelMessage] = field(default_factory=list) + """Messages exchanged in the conversation so far.""" + tool_call_id: str | None = None + """The ID of the tool call.""" + tool_name: str | None = None + """Name of the tool being called.""" + retry: int = 0 + """Number of retries so far.""" + run_step: int = 0 + """The current step in the run.""" + + def replace_with( + self, + retry: int | None = None, + tool_name: str | None | _utils.Unset = _utils.UNSET, + ) -> RunContext[AgentDepsT]: + # Create a new `RunContext` a new `retry` value and `tool_name`. + kwargs = {} + if retry is not None: + kwargs['retry'] = retry + if tool_name is not _utils.UNSET: # pragma: no branch + kwargs['tool_name'] = tool_name + return dataclasses.replace(self, **kwargs) + + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index bb3401519..e3c97bed8 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -4,7 +4,7 @@ import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union +from typing import Any, Callable, Generic, Literal, Union from opentelemetry.trace import Tracer from pydantic import ValidationError @@ -14,10 +14,7 @@ from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry, UnexpectedModelBehavior - -if TYPE_CHECKING: - from .models import Model - from .result import Usage +from .run_context import AgentDepsT, RunContext __all__ = ( 'AgentDepsT', @@ -35,48 +32,6 @@ 'ToolDefinition', ) -AgentDepsT = TypeVar('AgentDepsT', default=None, contravariant=True) -"""Type variable for agent dependencies.""" - - -@dataclasses.dataclass(repr=False) -class RunContext(Generic[AgentDepsT]): - """Information about the current call.""" - - deps: AgentDepsT - """Dependencies for the agent.""" - model: Model - """The model used in this run.""" - usage: Usage - """LLM usage associated with the run.""" - prompt: str | Sequence[_messages.UserContent] | None - """The original user prompt passed to the run.""" - messages: list[_messages.ModelMessage] = field(default_factory=list) - """Messages exchanged in the conversation so far.""" - tool_call_id: str | None = None - """The ID of the tool call.""" - tool_name: str | None = None - """Name of the tool being called.""" - retry: int = 0 - """Number of retries so far.""" - run_step: int = 0 - """The current step in the run.""" - - def replace_with( - self, - retry: int | None = None, - tool_name: str | None | _utils.Unset = _utils.UNSET, - ) -> RunContext[AgentDepsT]: - # Create a new `RunContext` a new `retry` value and `tool_name`. - kwargs = {} - if retry is not None: - kwargs['retry'] = retry - if tool_name is not _utils.UNSET: # pragma: no branch - kwargs['tool_name'] = tool_name - return dataclasses.replace(self, **kwargs) - - __repr__ = _utils.dataclasses_no_defaults_repr - ToolParams = ParamSpec('ToolParams', default=...) """Retrieval function param spec.""" From 4029facbc0b3b277982ddb11acfb5e80d9aff4fd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 17 Jun 2025 23:32:29 +0000 Subject: [PATCH 44/90] Make _run_context module private, RunContext can be accessed through the tools module --- pydantic_ai_slim/pydantic_ai/_cli.py | 2 +- pydantic_ai_slim/pydantic_ai/_function_schema.py | 2 +- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- .../pydantic_ai/{run_context.py => _run_context.py} | 0 pydantic_ai_slim/pydantic_ai/_system_prompt.py | 2 +- pydantic_ai_slim/pydantic_ai/result.py | 2 +- pydantic_ai_slim/pydantic_ai/tools.py | 2 +- 7 files changed, 6 insertions(+), 6 deletions(-) rename pydantic_ai_slim/pydantic_ai/{run_context.py => _run_context.py} (100%) diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index fa974349a..a0e5361ea 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -15,12 +15,12 @@ from typing_inspection.introspection import get_literal_values from . import __version__ +from ._run_context import AgentDepsT from .agent import Agent from .exceptions import UserError from .messages import ModelMessage from .models import KnownModelName, infer_model from .output import OutputDataT -from .run_context import AgentDepsT try: import argcomplete diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 9ed8bfdd4..7aa7c9cca 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -20,8 +20,8 @@ from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar, get_origin from ._griffe import doc_descriptions +from ._run_context import RunContext from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor -from .run_context import RunContext if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 27ed7455f..359153272 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -12,6 +12,7 @@ from typing_extensions import TypedDict, TypeVar, assert_never from . import _function_schema, _utils, messages as _messages +from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, UserError from .output import ( ModelStructuredOutput, @@ -27,7 +28,6 @@ ToolRetryError, _flatten_output_spec, # pyright: ignore[reportPrivateUsage] ) -from .run_context import AgentDepsT, RunContext from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition if TYPE_CHECKING: diff --git a/pydantic_ai_slim/pydantic_ai/run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py similarity index 100% rename from pydantic_ai_slim/pydantic_ai/run_context.py rename to pydantic_ai_slim/pydantic_ai/_run_context.py diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index 7a3b1c9c3..55bad733c 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, cast from . import _utils -from .run_context import AgentDepsT, RunContext +from ._run_context import AgentDepsT, RunContext from .tools import SystemPromptFunc diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 3e0542b26..a8b46f029 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -20,12 +20,12 @@ TextOutputSchema, ToolOutputSchema, ) +from ._run_context import AgentDepsT, RunContext from .messages import AgentStreamEvent, FinalResultEvent from .output import ( OutputDataT, ToolOutput, ) -from .run_context import AgentDepsT, RunContext from .usage import Usage, UsageLimits __all__ = ( diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index e3c97bed8..d4e3bcd75 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -13,8 +13,8 @@ from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar from . import _function_schema, _utils, messages as _messages +from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, UnexpectedModelBehavior -from .run_context import AgentDepsT, RunContext __all__ = ( 'AgentDepsT', From 8041cf3030e8f0491d3f2724a5c61cf0f8c89f6a Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 19 Jun 2025 18:49:36 +0000 Subject: [PATCH 45/90] Fix thinking part related tests --- tests/models/test_anthropic.py | 2 +- tests/models/test_bedrock.py | 2 +- tests/models/test_google.py | 2 +- tests/models/test_openai.py | 13 +++++++++---- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index f6e0a538a..81151e711 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1077,7 +1077,6 @@ async def test_anthropic_model_thinking_part_stream(allow_model_requests: None, assert event_parts == snapshot( [ PartStartEvent(index=0, part=ThinkingPart(content='', signature='')), - FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), @@ -1088,6 +1087,7 @@ async def test_anthropic_model_thinking_part_stream(allow_model_requests: None, PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartStartEvent(index=1, part=IsInstance(TextPart)), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent( index=1, delta=TextPartDelta( diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 29d28a09c..b851962a2 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -712,7 +712,6 @@ async def test_bedrock_model_thinking_part_stream(allow_model_requests: None, be assert event_parts == snapshot( [ PartStartEvent(index=0, part=ThinkingPart(content='Okay')), - FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=', so')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' the')), PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' user is')), @@ -1136,6 +1135,7 @@ async def test_bedrock_model_thinking_part_stream(allow_model_requests: None, be """ ), ), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Crossing the')), PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' street safely involves')), PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' careful')), diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 6ecadd7a3..3a87035dd 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -705,13 +705,13 @@ async def test_google_model_thinking_part_iter(allow_model_requests: None, googl assert event_parts == snapshot( [ PartStartEvent(index=0, part=IsInstance(ThinkingPart)), - FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), PartStartEvent(index=1, part=IsInstance(TextPart)), + FinalResultEvent(tool_name=None, tool_call_id=None), PartDeltaEvent(index=1, delta=IsInstance(TextPartDelta)), PartDeltaEvent(index=1, delta=IsInstance(TextPartDelta)), PartDeltaEvent(index=1, delta=IsInstance(TextPartDelta)), diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 229437d12..68e633b1d 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -29,6 +29,7 @@ RetryPromptPart, SystemPromptPart, TextPart, + TextPartDelta, ThinkingPart, ThinkingPartDelta, ToolCallPart, @@ -1796,10 +1797,14 @@ async def test_openai_model_thinking_part_iter(allow_model_requests: None, opena assert event_parts == snapshot( IsListOrTuple( - PartStartEvent(index=0, part=ThinkingPart(content='', signature=IsStr())), - FinalResultEvent(tool_name=None, tool_call_id=None), - PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), - length=(3, ...), + positions={ + 0: PartStartEvent(index=0, part=ThinkingPart(content='', signature=IsStr())), + 1: PartDeltaEvent(index=0, delta=IsInstance(ThinkingPartDelta)), + 87: PartStartEvent(index=1, part=TextPart(content="I'm")), + 88: FinalResultEvent(tool_name=None, tool_call_id=None), + 89: PartDeltaEvent(index=1, delta=IsInstance(TextPartDelta)), + }, + length=443, ) ) From 9bfed04333622306287e72aef99df3998e0481e5 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 20 Jun 2025 23:08:55 +0000 Subject: [PATCH 46/90] Implement Toolset --- docs/mcp/client.md | 23 +- mcp-run-python/README.md | 2 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 228 +++--- pydantic_ai_slim/pydantic_ai/_output.py | 16 +- pydantic_ai_slim/pydantic_ai/_run_context.py | 17 +- pydantic_ai_slim/pydantic_ai/agent.py | 130 ++-- pydantic_ai_slim/pydantic_ai/mcp.py | 78 +- pydantic_ai_slim/pydantic_ai/result.py | 4 +- pydantic_ai_slim/pydantic_ai/tools.py | 115 +-- pydantic_ai_slim/pydantic_ai/toolset.py | 736 +++++++++++++++++++ tests/test_examples.py | 10 +- tests/test_mcp.py | 89 ++- tests/test_tools.py | 64 +- tests/test_toolset.py | 114 +++ 14 files changed, 1177 insertions(+), 449 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/toolset.py create mode 100644 tests/test_toolset.py diff --git a/docs/mcp/client.md b/docs/mcp/client.md index f4d5bcc28..96da38fe0 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -29,7 +29,7 @@ Examples of both are shown below; [mcp-run-python](run-python.md) is used as the [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. !!! note - [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not managed by PydanticAI. + [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_toolsets()`][pydantic_ai.Agent.run_toolsets]. Running the server is not managed by PydanticAI. The name "HTTP" is used since this implemented will be adapted in future to use the new [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. @@ -51,7 +51,7 @@ agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent.run_toolsets(): # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -93,7 +93,7 @@ Will display as follows: !!! note [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be running and accepting HTTP connections before calling - [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not + [`agent.run_toolsets()`][pydantic_ai.Agent.run_toolsets]. Running the server is not managed by PydanticAI. Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. @@ -120,7 +120,7 @@ server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent.run_toolsets(): # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -137,7 +137,7 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. !!! note - When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers] context manager is responsible for starting and stopping the server. + When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`agent.run_toolsets()`][pydantic_ai.Agent.run_toolsets] context manager is responsible for starting and stopping the server. ```python {title="mcp_stdio_client.py" py="3.10"} from pydantic_ai import Agent @@ -159,7 +159,7 @@ agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -179,19 +179,20 @@ call needs. from typing import Any from pydantic_ai import Agent -from pydantic_ai.mcp import CallToolFunc, MCPServerStdio, ToolResult +from pydantic_ai.mcp import MCPServerStdio, ToolResult from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext +from pydantic_ai.toolset import CallToolFunc async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, - tool_name: str, - args: dict[str, Any], + name: str, + tool_args: dict[str, Any], ) -> ToolResult: """A tool call processor that passes along the deps.""" - return await call_tool(tool_name, args, metadata={'deps': ctx.deps}) + return await call_tool(name, tool_args, metadata={'deps': ctx.deps}) server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) @@ -203,7 +204,7 @@ agent = Agent( async def main(): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Echo with deps set to 42', deps=42) print(result.output) #> {"echo_deps":{"echo":"This is an echo message","deps":42}} diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index 360ca2347..0d57fb762 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -56,7 +56,7 @@ agent = Agent('claude-3-5-haiku-latest', mcp_servers=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025.w diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 864d269a2..bfa2cc8b4 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -3,6 +3,7 @@ import asyncio import dataclasses import hashlib +import json from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar @@ -10,20 +11,22 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast from opentelemetry.trace import Tracer +from pydantic import ValidationError from typing_extensions import TypeGuard, TypeVar, assert_never from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._utils import is_async_callable, run_in_executor +from pydantic_ai.toolset import AbstractToolset, CombinedToolset from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage from .output import OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings -from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc +from .tools import RunContext if TYPE_CHECKING: - from .mcp import MCPServer + pass __all__ = ( 'GraphAgentState', @@ -103,18 +106,15 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]] output_schema: _output.OutputSchema[OutputDataT] + output_toolset: AbstractToolset[DepsT] output_validators: list[_output.OutputValidator[DepsT, OutputDataT]] history_processors: Sequence[HistoryProcessor[DepsT]] - function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) - mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - default_retries: int + toolset: AbstractToolset[DepsT] tracer: Tracer - prepare_tools: ToolsPrepareFunc[DepsT] | None = None - class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): """The base class for all agent nodes. @@ -244,61 +244,25 @@ async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" - function_tool_defs_map: dict[str, ToolDefinition] = {} - run_context = build_run_context(ctx) + ctx.deps.toolset = toolset = await ctx.deps.toolset.freeze_for_run(run_context) + ctx.deps.output_toolset = output_toolset = await ctx.deps.output_toolset.freeze_for_run(run_context) - async def add_tool(tool: Tool[DepsT]) -> None: - ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name) - if tool_def := await tool.prepare_tool_def(ctx): - # prepare_tool_def may change tool_def.name - if tool_def.name in function_tool_defs_map: - if tool_def.name != tool.name: - # Prepare tool def may have renamed the tool - raise exceptions.UserError( - f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool." - ) - else: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.') - function_tool_defs_map[tool_def.name] = tool_def - - async def add_mcp_server_tools(server: MCPServer) -> None: - if not server.is_running: - raise exceptions.UserError(f'MCP server is not running: {server}') - tool_defs = await server.list_tools() - for tool_def in tool_defs: - if tool_def.name in function_tool_defs_map: - raise exceptions.UserError( - f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts." - ) - function_tool_defs_map[tool_def.name] = tool_def - - await asyncio.gather( - *map(add_tool, ctx.deps.function_tools.values()), - *map(add_mcp_server_tools, ctx.deps.mcp_servers), - ) - function_tool_defs = list(function_tool_defs_map.values()) - if ctx.deps.prepare_tools: - # Prepare the tools using the provided function - # This also acts over tool definitions pulled from MCP servers - function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] + # This will raise errors for any name conflicts + CombinedToolset[DepsT]([output_toolset, toolset]) output_schema = ctx.deps.output_schema - - output_tools = [] output_object = None - if isinstance(output_schema, _output.ToolOutputSchema): - output_tools = output_schema.tool_defs() - elif isinstance(output_schema, _output.ModelStructuredOutputSchema): + if isinstance(output_schema, _output.ModelStructuredOutputSchema): output_object = output_schema.object_def # ToolOrTextOutputSchema, ModelStructuredOutputSchema, and PromptedStructuredOutputSchema all inherit from TextOutputSchema allow_text_output = isinstance(output_schema, _output.TextOutputSchema) return models.ModelRequestParameters( - function_tools=function_tool_defs, + function_tools=toolset.tool_defs, output_mode=output_schema.mode, - output_tools=output_tools, + output_tools=output_toolset.tool_defs, output_object=output_object, allow_text_output=allow_text_output, ) @@ -521,6 +485,7 @@ async def _handle_tool_calls( final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] + # TODO: Can we make output tools a toolset? How does CallToolsNode know the result is final, and not be sent back? # first, look for the output tool call if isinstance(output_schema, _output.ToolOutputSchema): for call, output_tool in output_schema.find_tool(tool_calls): @@ -634,9 +599,8 @@ async def process_function_tools( # noqa C901 # we rely on the fact that if we found a result, it's the first output tool in the last found_used_output_tool = False - run_context = build_run_context(ctx) - calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = [] + calls_to_run: list[_messages.ToolCallPart] = [] call_index_to_event_id: dict[int, str] = {} for call in tool_calls: if ( @@ -652,36 +616,7 @@ async def process_function_tools( # noqa C901 tool_call_id=call.tool_call_id, ) ) - elif tool := ctx.deps.function_tools.get(call.tool_name): - if stub_function_tools: - output_parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - event = _messages.FunctionToolCallEvent(call) - yield event - call_index_to_event_id[len(calls_to_run)] = event.call_id - calls_to_run.append((tool, call)) - elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx): - if stub_function_tools: - # TODO(Marcelo): We should add coverage for this part of the code. - output_parts.append( # pragma: no cover - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - event = _messages.FunctionToolCallEvent(call) - yield event - call_index_to_event_id[len(calls_to_run)] = event.call_id - calls_to_run.append((mcp_tool, call)) - elif call.tool_name in output_schema.tools: + elif call.tool_name in output_schema.tools: # TODO: Check on toolset? # if tool_name is in output_schema, it means we found a output tool but an error occurred in # validation, we don't add another part here if output_tool_name is not None: @@ -698,10 +633,24 @@ async def process_function_tools( # noqa C901 ) yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id) output_parts.append(part) + elif call.tool_name in ctx.deps.toolset.tool_names: + if stub_function_tools: + output_parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + event = _messages.FunctionToolCallEvent(call) + yield event + call_index_to_event_id[len(calls_to_run)] = event.call_id + calls_to_run.append(call) else: yield _messages.FunctionToolCallEvent(call) - part = _unknown_tool(call.tool_name, call.tool_call_id, ctx) + part = await _unknown_tool(call.tool_name, call.tool_call_id, ctx) yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id) output_parts.append(part) @@ -715,13 +664,13 @@ async def process_function_tools( # noqa C901 with ctx.deps.tracer.start_as_current_span( 'running tools', attributes={ - 'tools': [call.tool_name for _, call in calls_to_run], + 'tools': [call.tool_name for call in calls_to_run], 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', }, ): tasks = [ - asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer), name=call.tool_name) - for tool, call in calls_to_run + asyncio.create_task(_execute_tool_call(call, ctx, ctx.deps.tracer), name=call.tool_name) + for call in calls_to_run ] pending = tasks @@ -735,17 +684,8 @@ async def process_function_tools( # noqa C901 if isinstance(result, _messages.RetryPromptPart): results_by_index[index] = result elif isinstance(result, _messages.ToolReturnPart): - contents: list[Any] - single_content: bool - if isinstance(result.content, list): - contents = result.content # type: ignore - single_content = False - else: - contents = [result.content] - single_content = True - processed_contents: list[Any] = [] - for content in contents: + def process_content(content: Any) -> Any: if isinstance(content, _messages.MultiModalContentTypes): if isinstance(content, _messages.BinaryContent): identifier = multi_modal_content_identifier(content.data) @@ -759,14 +699,15 @@ async def process_function_tools( # noqa C901 part_kind='user-prompt', ) ) - processed_contents.append(f'See file {identifier}') + return f'See file {identifier}' else: - processed_contents.append(content) + return content - if single_content: - result.content = processed_contents[0] + if isinstance(result.content, list): + contents = cast(list[Any], result.content) # type: ignore + result.content = [process_content(content) for content in contents] else: - result.content = processed_contents + result.content = process_content(result.content) results_by_index[index] = result else: @@ -780,51 +721,74 @@ async def process_function_tools( # noqa C901 output_parts.extend(user_parts) -async def _tool_from_mcp_server( - tool_name: str, +async def _execute_tool_call( + tool_call: _messages.ToolCallPart, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> Tool[DepsT] | None: - """Call each MCP server to find the tool with the given name. + tracer: Tracer, +) -> _messages.ToolReturnPart | _messages.RetryPromptPart: + """Run the tool function asynchronously. - Args: - tool_name: The name of the tool to find. - ctx: The current run context. - - Returns: - The tool with the given name, or `None` if no tool with the given name is found. + See . """ + span_attributes = { + 'gen_ai.tool.name': tool_call.tool_name, + # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai + 'gen_ai.tool.call.id': tool_call.tool_call_id, + 'tool_arguments': tool_call.args_as_json_str(), + 'logfire.msg': f'running tool: {tool_call.tool_name}', + # add the JSON schema so these attributes are formatted nicely in Logfire + 'logfire.json_schema': json.dumps( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ), + } - async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any: - # There's no normal situation where the server will not be running at this point, we check just in case - # some weird edge case occurs. - if not server.is_running: # pragma: no cover - raise exceptions.UserError(f'MCP server is not running: {server}') - - if server.process_tool_call is not None: - result = await server.process_tool_call(ctx, server.call_tool, tool_name, args) - else: - result = await server.call_tool(tool_name, args) + run_context = build_run_context(ctx) + toolset = ctx.deps.toolset + with tracer.start_as_current_span('running tool', attributes=span_attributes): + run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id) - return result + try: + args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args) + except ValidationError as e: + return _messages.RetryPromptPart( + tool_name=tool_call.tool_name, + content=e.errors(include_url=False, include_context=False), + tool_call_id=tool_call.tool_call_id, + ) + try: + response_content = await toolset.call_tool(run_context, tool_call.tool_name, args_dict) + except exceptions.ModelRetry as e: + return _messages.RetryPromptPart( + tool_name=tool_call.tool_name, + content=e.message, + tool_call_id=tool_call.tool_call_id, + ) - for server in ctx.deps.mcp_servers: - tools = await server.list_tools() - if tool_name in {tool.name for tool in tools}: # pragma: no branch - return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries) - return None + return _messages.ToolReturnPart( + tool_name=tool_call.tool_name, + content=response_content, + tool_call_id=tool_call.tool_call_id, + ) -def _unknown_tool( +async def _unknown_tool( tool_name: str, tool_call_id: str, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> _messages.RetryPromptPart: ctx.state.increment_retries(ctx.deps.max_result_retries) - tool_names = list(ctx.deps.function_tools.keys()) - output_schema = ctx.deps.output_schema - if isinstance(output_schema, _output.ToolOutputSchema): - tool_names.extend(output_schema.tool_names()) + tool_names = [ + *ctx.deps.toolset.tool_names, + *ctx.deps.output_toolset.tool_names, + ] if tool_names: msg = f'Available tools: {", ".join(tool_names)}' diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 359153272..230fb4d38 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -4,7 +4,7 @@ import json from abc import ABC, abstractmethod from collections.abc import Awaitable, Iterable, Iterator, Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload from pydantic import TypeAdapter, ValidationError @@ -95,7 +95,11 @@ async def validate( Result of either the validated result data (ok) or a retry message (Err). """ if self._takes_ctx: - ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None) + ctx = ( + replace(run_context, tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id) + if tool_call + else run_context + ) args = ctx, result else: args = (result,) @@ -502,14 +506,6 @@ def tools(self) -> dict[str, OutputTool[OutputDataT]]: """Get the tools for this output schema.""" return self._tools - def tool_names(self) -> list[str]: - """Return the names of the tools.""" - return list(self.tools.keys()) - - def tool_defs(self) -> list[ToolDefinition]: - """Get tool definitions to register with the model.""" - return [t.tool_def for t in self.tools.values()] - def find_named_tool( self, parts: Iterable[_messages.ModelResponsePart], tool_name: str ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index bb7f47420..5f705cd4f 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -27,10 +27,12 @@ class RunContext(Generic[AgentDepsT]): """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" - prompt: str | Sequence[_messages.UserContent] | None + prompt: str | Sequence[_messages.UserContent] | None = None """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) """Messages exchanged in the conversation so far.""" + retries: dict[str, int] = field(default_factory=dict) + """Number of retries for each tool.""" tool_call_id: str | None = None """The ID of the tool call.""" tool_name: str | None = None @@ -40,17 +42,4 @@ class RunContext(Generic[AgentDepsT]): run_step: int = 0 """The current step in the run.""" - def replace_with( - self, - retry: int | None = None, - tool_name: str | None | _utils.Unset = _utils.UNSET, - ) -> RunContext[AgentDepsT]: - # Create a new `RunContext` a new `retry` value and `tool_name`. - kwargs = {} - if retry is not None: - kwargs['retry'] = retry - if tool_name is not _utils.UNSET: # pragma: no branch - kwargs['tool_name'] = tool_name - return dataclasses.replace(self, **kwargs) - __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 421da8f44..0c508a2da 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -5,7 +5,7 @@ import json import warnings from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager +from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from copy import deepcopy from types import FrameType from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload @@ -15,6 +15,7 @@ from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated from pydantic_ai.profiles import ModelProfile +from pydantic_ai.toolset import AbstractToolset, CombinedToolset, FunctionToolset, OutputToolset, PreparedToolset from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -152,10 +153,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( repr=False ) - _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False) - _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - _default_retries: int = dataclasses.field(repr=False) + _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) + _output_toolset: OutputToolset[AgentDepsT] = dataclasses.field(repr=False) + _toolset: AbstractToolset[AgentDepsT] = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False) _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False) @@ -179,6 +179,7 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), + toolsets: Sequence[AbstractToolset[AgentDepsT]] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -209,6 +210,7 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), + toolsets: Sequence[AbstractToolset[AgentDepsT]] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -234,6 +236,7 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), + toolsets: Sequence[AbstractToolset[AgentDepsT]] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -267,6 +270,7 @@ def __init__( a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] for each server you want the agent to connect to. + toolsets: Toolsets to register with the agent. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` @@ -352,18 +356,24 @@ def __init__( self._system_prompt_functions = [] self._system_prompt_dynamic_functions = {} - self._function_tools = {} - - self._default_retries = retries self._max_result_retries = output_retries if output_retries is not None else retries - self._mcp_servers = mcp_servers - self._prepare_tools = prepare_tools + + self._output_toolset = OutputToolset[AgentDepsT](self._output_schema, max_retries=self._max_result_retries) + self._function_toolset = FunctionToolset[AgentDepsT](tools, max_retries=retries) + + # This will raise errors for any name conflicts + CombinedToolset[AgentDepsT]([self._output_toolset, self._function_toolset]) + + mcp_toolsets = [ + cast(AbstractToolset[AgentDepsT], mcp_server.as_toolset(max_retries=self._max_result_retries)) + for mcp_server in mcp_servers + ] + toolset = CombinedToolset[AgentDepsT]([self._function_toolset, *toolsets, *mcp_toolsets]) + if prepare_tools: + toolset = PreparedToolset[AgentDepsT](toolset, prepare_tools) + self._toolset = toolset + self.history_processors = history_processors or [] - for tool in tools: - if isinstance(tool, Tool): - self._register_tool(tool) - else: - self._register_tool(Tool(tool)) @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: @@ -643,6 +653,12 @@ async def main(): output_type_ = output_type or self.output_type + output_toolset = self._output_toolset + if output_schema != self._output_schema: + output_toolset = OutputToolset[AgentDepsT](output_schema, max_retries=self._max_result_retries) + # This will raise errors for any name conflicts + CombinedToolset[AgentDepsT]([output_toolset, self._toolset]) + # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) @@ -697,10 +713,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: return None return '\n\n'.join(parts).strip() - # Copy the function tools so that retry state is agent-run-specific - # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`. - run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()} - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( user_deps=deps, prompt=user_prompt, @@ -712,12 +724,10 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: end_strategy=self.end_strategy, output_schema=output_schema, output_validators=output_validators, + output_toolset=output_toolset, history_processors=self.history_processors, - function_tools=run_function_tools, - mcp_servers=self._mcp_servers, - default_retries=self._default_retries, + toolset=self._toolset, tracer=tracer, - prepare_tools=self._prepare_tools, get_instructions=get_instructions, ) start_node = _agent_graph.UserPromptNode[AgentDepsT]( @@ -1407,7 +1417,7 @@ def tool_decorator( func_: ToolFuncContext[AgentDepsT, ToolParams], ) -> ToolFuncContext[AgentDepsT, ToolParams]: # noinspection PyTypeChecker - self._register_function( + self._function_toolset.register_function( func_, True, name, @@ -1423,7 +1433,7 @@ def tool_decorator( return tool_decorator else: # noinspection PyTypeChecker - self._register_function( + self._function_toolset.register_function( func, True, name, @@ -1514,7 +1524,7 @@ async def spam(ctx: RunContext[str]) -> float: def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: # noinspection PyTypeChecker - self._register_function( + self._function_toolset.register_function( func_, False, name, @@ -1529,7 +1539,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams return tool_decorator else: - self._register_function( + self._function_toolset.register_function( func, False, name, @@ -1542,47 +1552,6 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams ) return func - def _register_function( - self, - func: ToolFuncEither[AgentDepsT, ToolParams], - takes_ctx: bool, - name: str | None, - retries: int | None, - prepare: ToolPrepareFunc[AgentDepsT] | None, - docstring_format: DocstringFormat, - require_parameter_descriptions: bool, - schema_generator: type[GenerateJsonSchema], - strict: bool | None, - ) -> None: - """Private utility to register a function as a tool.""" - retries_ = retries if retries is not None else self._default_retries - tool = Tool[AgentDepsT]( - func, - takes_ctx=takes_ctx, - name=name, - max_retries=retries_, - prepare=prepare, - docstring_format=docstring_format, - require_parameter_descriptions=require_parameter_descriptions, - schema_generator=schema_generator, - strict=strict, - ) - self._register_tool(tool) - - def _register_tool(self, tool: Tool[AgentDepsT]) -> None: - """Private utility to register a tool instance.""" - if tool.max_retries is None: - # noinspection PyTypeChecker - tool = dataclasses.replace(tool, max_retries=self._default_retries) - - if tool.name in self._function_tools: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') - - if tool.name in self._output_schema.tools: - raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}') - - self._function_tools[tool.name] = tool - def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model: """Create a model configured for this agent. @@ -1713,7 +1682,7 @@ def is_end_node( return isinstance(node, End) @asynccontextmanager - async def run_mcp_servers( + async def run_toolsets( self, model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. @@ -1724,16 +1693,23 @@ async def run_mcp_servers( sampling_model: models.Model | None = self._get_model(model) except exceptions.UserError: # pragma: no cover sampling_model = None + if sampling_model is not None: # pragma: no branch + self._toolset.set_mcp_sampling_model(sampling_model) - exit_stack = AsyncExitStack() - try: - for mcp_server in self._mcp_servers: - if sampling_model is not None: # pragma: no branch - mcp_server.sampling_model = sampling_model - await exit_stack.enter_async_context(mcp_server) + async with self._toolset: + yield + + @asynccontextmanager + @deprecated('`run_mcp_servers` is deprecated, use `run_toolsets` instead.') + async def run_mcp_servers( + self, model: models.Model | models.KnownModelName | str | None = None + ) -> AsyncIterator[None]: + """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. + + Returns: a context manager to start and shutdown the servers. + """ + async with self.run_toolsets(model): yield - finally: - await exit_stack.aclose() def to_a2a( self, diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 9574b077f..9574a830a 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -3,7 +3,7 @@ import base64 import functools from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Awaitable, Sequence +from collections.abc import AsyncIterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from dataclasses import dataclass from pathlib import Path @@ -16,6 +16,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from typing_extensions import Self, assert_never, deprecated +from .exceptions import UserError +from .toolset import AbstractToolset, MCPToolset, PrefixedToolset, ProcessedToolset, ToolProcessFunc + try: from mcp import types as mcp_types from mcp.client.session import ClientSession, LoggingFnT @@ -32,7 +35,7 @@ ) from _import_error # after mcp imports so any import error maps to this file, not _mcp.py -from . import _mcp, exceptions, messages, models, tools +from . import _mcp, exceptions, messages, models __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP' @@ -48,7 +51,7 @@ class MCPServer(ABC): log_level: mcp_types.LoggingLevel | None = None log_handler: LoggingFnT | None = None init_timeout: float = 5 - process_tool_call: ProcessToolCallback | None = None + process_tool_call: ToolProcessFunc | None = None # } end of "abstract fields" _running_count: int = 0 @@ -73,35 +76,22 @@ async def client_streams( raise NotImplementedError('MCP Server subclasses must implement this method.') yield - def get_prefixed_tool_name(self, tool_name: str) -> str: - """Get the tool name with prefix if `tool_prefix` is set.""" - return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name - - def get_unprefixed_tool_name(self, tool_name: str) -> str: - """Get original tool name without prefix for calling tools.""" - return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name - @property def is_running(self) -> bool: """Check if the MCP server is running.""" return bool(self._running_count) - async def list_tools(self) -> list[tools.ToolDefinition]: + async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. Note: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ - mcp_tools = await self._client.list_tools() - return [ - tools.ToolDefinition( - name=self.get_prefixed_tool_name(tool.name), - description=tool.description or '', - parameters_json_schema=tool.inputSchema, - ) - for tool in mcp_tools.tools - ] + if not self.is_running: # pragma: no cover + raise UserError(f'MCP server is not running: {self}') + result = await self._client.list_tools() + return result.tools async def call_tool( self, @@ -122,6 +112,8 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ + if not self.is_running: # pragma: no cover + raise UserError(f'MCP server is not running: {self}') try: # meta param is not provided by session yet, so build and can send_request directly. result = await self._client.send_request( @@ -129,7 +121,7 @@ async def call_tool( mcp_types.CallToolRequest( method='tools/call', params=mcp_types.CallToolRequestParams( - name=self.get_unprefixed_tool_name(tool_name), + name=tool_name, arguments=arguments, _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, ), @@ -148,6 +140,14 @@ async def call_tool( else: return content[0] if len(content) == 1 else content + def as_toolset(self, max_retries: int = 1) -> AbstractToolset: + toolset = MCPToolset(self, max_retries=max_retries) + if self.process_tool_call: + toolset = ProcessedToolset(toolset, self.process_tool_call) + if self.tool_prefix: + toolset = PrefixedToolset(toolset, self.tool_prefix) + return toolset + async def __aenter__(self) -> Self: if self._running_count == 0: self._exit_stack = AsyncExitStack() @@ -273,7 +273,7 @@ class MCPServerStdio(MCPServer): agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent.run_toolsets(): # (2)! ... ``` @@ -323,7 +323,7 @@ async def main(): timeout: float = 5 """ The timeout in seconds to wait for the client to initialize.""" - process_tool_call: ProcessToolCallback | None = None + process_tool_call: ToolProcessFunc | None = None """Hook to customize tool calling and optionally pass extra metadata.""" @asynccontextmanager @@ -418,7 +418,7 @@ class _MCPServerHTTP(MCPServer): init_timeout: float = 5 """The timeout in seconds to wait for the client to initialize.""" - process_tool_call: ProcessToolCallback | None = None + process_tool_call: ToolProcessFunc | None = None """Hook to customize tool calling and optionally pass extra metadata.""" @property @@ -505,7 +505,7 @@ class MCPServerSSE(_MCPServerHTTP): agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent.run_toolsets(): # (2)! ... ``` @@ -539,7 +539,7 @@ class MCPServerHTTP(MCPServerSSE): agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent.run_toolsets(): # (2)! ... ``` @@ -568,7 +568,7 @@ class MCPServerStreamableHTTP(_MCPServerHTTP): agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent.run_toolsets(): # (2)! ... ``` """ @@ -585,24 +585,4 @@ def _transport_client(self): | list[Any] | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] ) -"""The result type of a tool call.""" - -CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]] -"""A function type that represents a tool call.""" - -ProcessToolCallback = Callable[ - [ - tools.RunContext[Any], - CallToolFunc, - str, - dict[str, Any], - ], - Awaitable[ToolResult], -] -"""A process tool callback. - -It accepts a run context, the original tool call function, a tool name, and arguments. - -Allows wrapping an MCP server tool call to customize it, including adding extra request -metadata. -""" +"""The result type of an MCP tool call.""" diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index a8b46f029..d7d6a51cb 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -95,7 +95,7 @@ async def _validate_response( match = self._output_schema.find_named_tool(message.parts, output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + 'Invalid response, unable to find tool' ) call, output_tool = match @@ -413,7 +413,7 @@ async def validate_structured_output( match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + 'Invalid response, unable to find tool' ) call, output_tool = match diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index d4e3bcd75..200b17506 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,20 +1,15 @@ from __future__ import annotations as _annotations -import dataclasses -import json from collections.abc import Awaitable, Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Generic, Literal, Union -from opentelemetry.trace import Tracer -from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import SchemaValidator, core_schema from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar -from . import _function_schema, _utils, messages as _messages +from . import _function_schema, _utils from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UnexpectedModelBehavior __all__ = ( 'AgentDepsT', @@ -172,12 +167,6 @@ class Tool(Generic[AgentDepsT]): This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request. """ - # TODO: Consider moving this current_retry state to live on something other than the tool. - # We've worked around this for now by copying instances of the tool when creating new runs, - # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things - # up, though is also likely a larger effort to refactor. - current_retry: int = field(default=0, init=False) - def __init__( self, function: ToolFuncEither[AgentDepsT], @@ -302,6 +291,15 @@ def from_schema( function_schema=function_schema, ) + @property + def tool_def(self): + return ToolDefinition( + name=self.name, + description=self.description, + parameters_json_schema=self.function_schema.json_schema, + strict=self.strict, + ) + async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. @@ -311,96 +309,11 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition Returns: return a `ToolDefinition` or `None` if the tools should not be registered for this run. """ - tool_def = ToolDefinition( - name=self.name, - description=self.description, - parameters_json_schema=self.function_schema.json_schema, - strict=self.strict, - ) + standard_tool_def = self.tool_def if self.prepare is not None: - return await self.prepare(ctx, tool_def) - else: - return tool_def - - async def run( - self, - message: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - tracer: Tracer, - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - """Run the tool function asynchronously. - - This method wraps `_run` in an OpenTelemetry span. - - See . - """ - span_attributes = { - 'gen_ai.tool.name': self.name, - # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai - 'gen_ai.tool.call.id': message.tool_call_id, - 'tool_arguments': message.args_as_json_str(), - 'logfire.msg': f'running tool: {self.name}', - # add the JSON schema so these attributes are formatted nicely in Logfire - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - 'tool_arguments': {'type': 'object'}, - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ), - } - with tracer.start_as_current_span('running tool', attributes=span_attributes): - return await self._run(message, run_context) - - async def _run( - self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT] - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - try: - validator = self.function_schema.validator - if isinstance(message.args, str): - args_dict = validator.validate_json(message.args or '{}') - else: - args_dict = validator.validate_python(message.args or {}) - except ValidationError as e: - return self._on_error(e, message) - - ctx = dataclasses.replace( - run_context, - retry=self.current_retry, - tool_name=message.tool_name, - tool_call_id=message.tool_call_id, - ) - try: - response_content = await self.function_schema.call(args_dict, ctx) - except ModelRetry as e: - return self._on_error(e, message) - - self.current_retry = 0 - return _messages.ToolReturnPart( - tool_name=message.tool_name, - content=response_content, - tool_call_id=message.tool_call_id, - ) - - def _on_error( - self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart - ) -> _messages.RetryPromptPart: - self.current_retry += 1 - if self.max_retries is None or self.current_retry > self.max_retries: - raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc + return await self.prepare(ctx, standard_tool_def) else: - if isinstance(exc, ValidationError): - content = exc.errors(include_url=False, include_context=False) - else: - content = exc.message - return _messages.RetryPromptPart( - tool_name=call_message.tool_name, - content=content, - tool_call_id=call_message.tool_call_id, - ) + return standard_tool_def ObjectJsonSchema: TypeAlias = dict[str, Any] diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py new file mode 100644 index 000000000..bd58c53b1 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -0,0 +1,736 @@ +from __future__ import annotations + +import asyncio +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Sequence +from contextlib import AsyncExitStack +from dataclasses import dataclass, field, replace +from functools import partial +from types import TracebackType +from typing import TYPE_CHECKING, Any, Callable, Generic, Never, Protocol, overload + +from pydantic import ValidationError +from pydantic.json_schema import GenerateJsonSchema +from pydantic_core import SchemaValidator, core_schema +from typing_extensions import Self + +from ._output import BaseOutputSchema +from ._run_context import AgentDepsT, RunContext +from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError +from .tools import ( + DocstringFormat, + GenerateToolJsonSchema, + Tool, + ToolDefinition, + ToolFuncEither, + ToolParams, + ToolPrepareFunc, + ToolsPrepareFunc, +) + +if TYPE_CHECKING: + from pydantic_ai.mcp import MCPServer + from pydantic_ai.models import Model + + +class AbstractToolset(ABC, Generic[AgentDepsT]): + """A toolset is a collection of tools that can be used by an agent. + + It is responsible for: + - Listing the tools it contains + - Validating the arguments of the tools + - Calling the tools + """ + + @property + def name(self) -> str: + return self.__class__.__name__ + + @property + def name_conflict_hint(self) -> str: + return 'Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.' + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return None + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + return FrozenToolset[AgentDepsT](self, ctx) + + @property + @abstractmethod + def tool_defs(self) -> list[ToolDefinition]: + raise NotImplementedError() + + @property + def tool_names(self) -> list[str]: + return [tool_def.name for tool_def in self.tool_defs] + + @abstractmethod + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + raise NotImplementedError() + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None + ) -> dict[str, Any]: + validator = self.get_tool_args_validator(ctx, name) + if isinstance(args, str): + return validator.validate_json(args or '{}') + else: + return validator.validate_python(args or {}) + + @abstractmethod + def max_retries_for_tool(self, name: str) -> int: + raise NotImplementedError() + + @abstractmethod + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + raise NotImplementedError() + + def set_mcp_sampling_model(self, model: Model) -> None: + pass + + +@dataclass(init=False) +class _BareFunctionToolset(AbstractToolset[AgentDepsT]): + """A toolset that functions can be registered to as tools.""" + + max_retries: int = field(default=1) + tools: dict[str, Tool[Any]] = field(default_factory=dict) + + @property + def name(self) -> str: + return 'FunctionToolset' + + def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + self.max_retries = max_retries + self.tools = {} + for tool in tools: + if isinstance(tool, Tool): + self.register_tool(tool) + else: + self.register_function(tool) + + @overload + def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... + + @overload + def tool( + self, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ... + + def tool( + self, + func: ToolFuncEither[AgentDepsT, ToolParams] | None = None, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Any: + """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. + + Can decorate a sync or async functions. + + The docstring is inspected to extract both the tool description and description of each parameter, + [learn more](../tools.md#function-tools-and-schema). + + We can't add overloads for every possible signature of tool, since the return type is a recursive union + so the signature of functions decorated with `@agent.tool` is obscured. + + Example: + ```python + from pydantic_ai import Agent, RunContext + + agent = Agent('test', deps_type=int) + + @agent.tool + def foobar(ctx: RunContext[int], x: int) -> int: + return ctx.deps + x + + @agent.tool(retries=2) + async def spam(ctx: RunContext[str], y: float) -> float: + return ctx.deps + y + + result = agent.run_sync('foobar', deps=1) + print(result.output) + #> {"foobar":1,"spam":1.0} + ``` + + Args: + func: The tool function to register. + name: The name of the tool, defaults to the function name. + retries: The number of retries to allow for this tool, defaults to the agent's default retries, + which defaults to 1. + prepare: custom method to prepare the tool definition for each step, return `None` to omit this + tool from a given step. This is useful if you want to customise a tool at call time, + or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. + docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. + Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. + require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. + schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. + strict: Whether to enforce JSON schema compliance (only affects OpenAI). + See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + """ + if func is None: + + def tool_decorator( + func_: ToolFuncEither[AgentDepsT, ToolParams], + ) -> ToolFuncEither[AgentDepsT, ToolParams]: + # noinspection PyTypeChecker + self.register_function( + func_, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func_ + + return tool_decorator + else: + # noinspection PyTypeChecker + self.register_function( + func, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func + + def register_function( + self, + func: ToolFuncEither[AgentDepsT, ToolParams], + takes_ctx: bool | None = None, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> None: + """Register a function as a tool.""" + tool = Tool[AgentDepsT]( + func, + takes_ctx=takes_ctx, + name=name, + max_retries=retries, + prepare=prepare, + docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions, + schema_generator=schema_generator, + strict=strict, + ) + self.register_tool(tool) + + def register_tool(self, tool: Tool[AgentDepsT]) -> None: + if tool.name in self.tools: + raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}') + if tool.max_retries is None: + tool.max_retries = self.max_retries + self.tools[tool.name] = tool + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [tool.tool_def for tool in self.tools.values()] + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.tools[name].function_schema.validator + + def max_retries_for_tool(self, name: str) -> int: + tool = self.tools[name] + return tool.max_retries if tool.max_retries is not None else self.max_retries + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.tools[name].function_schema.call(tool_args, ctx) + + +@dataclass +class OutputToolset(AbstractToolset[AgentDepsT]): + """A toolset that contains output tools.""" + + output_schema: BaseOutputSchema[Any] + max_retries: int = field(default=1) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [tool.tool_def for tool in self.output_schema.tools.values()] + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + # TODO: Should never be called for an output tool? + return self.output_schema.tools[name].processor._validator # pyright: ignore[reportPrivateUsage] + + def max_retries_for_tool(self, name: str) -> int: + return self.max_retries + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + # TODO: Should never be called for an output tool? + return await self.output_schema.tools[name].processor.process(tool_args, ctx) + + +@dataclass +class MCPToolset(AbstractToolset[AgentDepsT]): + """A toolset that contains MCP tools and handles running the server.""" + + server: MCPServer + max_retries: int = field(default=1) + + @property + def name(self) -> str: + return repr(self.server) + + @property + def name_conflict_hint(self) -> str: + return 'Consider setting `tool_prefix` to avoid name conflicts.' + + async def __aenter__(self) -> Self: + await self.server.__aenter__() + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return await self.server.__aexit__(exc_type, exc_value, traceback) + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + return FrozenToolset[AgentDepsT](self, ctx, await self.list_tool_defs()) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [] + + async def list_tool_defs(self) -> list[ToolDefinition]: + mcp_tools = await self.server.list_tools() + return [ + ToolDefinition( + name=mcp_tool.name, + description=mcp_tool.description or '', + parameters_json_schema=mcp_tool.inputSchema, + ) + for mcp_tool in mcp_tools + ] + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return SchemaValidator(schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema())) + + def max_retries_for_tool(self, name: str) -> int: + return self.max_retries + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], metadata: dict[str, Any] | None = None + ) -> Any: + return await self.server.call_tool(name, tool_args, metadata) + + def set_mcp_sampling_model(self, model: Model) -> None: + self.server.sampling_model = model + + +@dataclass +class WrapperToolset(AbstractToolset[AgentDepsT]): + """A toolset that wraps another toolset and delegates to it.""" + + wrapped: AbstractToolset[AgentDepsT] + + @property + def name(self) -> str: + return self.wrapped.name + + @property + def name_conflict_hint(self) -> str: + return self.wrapped.name_conflict_hint + + async def __aenter__(self) -> Self: + await self.wrapped.__aenter__() + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return await self.wrapped.__aexit__(exc_type, exc_value, traceback) + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + raise NotImplementedError() + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self.wrapped.tool_defs + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.wrapped.get_tool_args_validator(ctx, name) + + def max_retries_for_tool(self, name: str) -> int: + return self.wrapped.max_retries_for_tool(name) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) + + def set_mcp_sampling_model(self, model: Model) -> None: + self.wrapped.set_mcp_sampling_model(model) + + def __getattr__(self, item: str): + return getattr(self.wrapped, item) # pragma: no cover + + +@dataclass(init=False) +class CombinedToolset(AbstractToolset[AgentDepsT]): + """A toolset that combines multiple toolsets.""" + + toolsets: list[AbstractToolset[AgentDepsT]] + _exit_stack: AsyncExitStack | None + _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]] + + def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): + self._exit_stack = None + self.toolsets = list(toolsets) + + self._toolset_per_tool_name = {} + for toolset in self.toolsets: + for name in toolset.tool_names: + try: + existing_toolset = self._toolset_per_tool_name[name] + raise UserError( + f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset.name_conflict_hint}' + ) + except KeyError: + pass + self._toolset_per_tool_name[name] = toolset + + async def __aenter__(self) -> Self: + # TODO: running_count thing like in MCPServer? + self._exit_stack = AsyncExitStack() + for toolset in self.toolsets: + await self._exit_stack.enter_async_context(toolset) + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + if self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + return None + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_toolsets = await asyncio.gather(*[toolset.freeze_for_run(ctx) for toolset in self.toolsets]) + freezable_combined = CombinedToolset[AgentDepsT](frozen_toolsets) + return FrozenToolset[AgentDepsT](freezable_combined, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [tool_def for toolset in self.toolsets for tool_def in toolset.tool_defs] + + @property + def tool_names(self) -> list[str]: + return list(self._toolset_per_tool_name.keys()) + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self._toolset_for_tool_name(name).get_tool_args_validator(ctx, name) + + def max_retries_for_tool(self, name: str) -> int: + return self._toolset_for_tool_name(name).max_retries_for_tool(name) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) + + def set_mcp_sampling_model(self, model: Model) -> None: + for toolset in self.toolsets: + toolset.set_mcp_sampling_model(model) + + def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: + try: + return self._toolset_per_tool_name[name] + except KeyError as e: + raise ValueError(f'Tool {name!r} not found in any toolset') from e + + +@dataclass +class PrefixedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prefixes the names of the tools it contains.""" + + prefix: str + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_wrapped = await self.wrapped.freeze_for_run(ctx) + freezable_prefixed = PrefixedToolset[AgentDepsT](frozen_wrapped, self.prefix) + return FrozenToolset[AgentDepsT](freezable_prefixed, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs] + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super().get_tool_args_validator(ctx, self._unprefixed_tool_name(name)) + + def max_retries_for_tool(self, name: str) -> int: + return super().max_retries_for_tool(self._unprefixed_tool_name(name)) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await super().call_tool(ctx, self._unprefixed_tool_name(name), tool_args, *args, **kwargs) + + def _prefixed_tool_name(self, tool_name: str) -> str: + return f'{self.prefix}_{tool_name}' + + def _unprefixed_tool_name(self, tool_name: str) -> str: + full_prefix = f'{self.prefix}_' + if not tool_name.startswith(full_prefix): + raise ValueError(f"Tool name '{tool_name}' does not start with prefix '{full_prefix}'") + return tool_name[len(full_prefix) :] + + +@dataclass +class PreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a prepare function.""" + + prepare_func: ToolsPrepareFunc[AgentDepsT] + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_wrapped = await self.wrapped.freeze_for_run(ctx) + original_tool_defs = frozen_wrapped.tool_defs + prepared_tool_defs = await self.prepare_func(ctx, original_tool_defs) or [] + + original_tool_names = {tool_def.name for tool_def in original_tool_defs} + prepared_tool_names = {tool_def.name for tool_def in prepared_tool_defs} + if len(prepared_tool_names - original_tool_names) > 0: + raise UserError('Prepare function is not allowed to change tool names or add new tools.') + + freezable_prepared = PreparedToolset[AgentDepsT](frozen_wrapped, self.prepare_func) + return FrozenToolset[AgentDepsT](freezable_prepared, ctx, prepared_tool_defs) + + +@dataclass +class IndividuallyPreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a per-tool prepare function.""" + + prepare_func: ToolPrepareFunc[AgentDepsT] + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_wrapped = await self.wrapped.freeze_for_run(ctx) + + tool_defs: dict[str, ToolDefinition] = {} + name_map: dict[str, str] = {} + for original_tool_def in frozen_wrapped.tool_defs: + original_name = original_tool_def.name + tool_def = await self.prepare_func(ctx, original_tool_def) + if not tool_def: + continue + + new_name = tool_def.name + if new_name in tool_defs: + if new_name != original_name: + raise UserError(f"Renaming tool '{original_name}' to '{new_name}' conflicts with existing tool.") + else: + raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') + name_map[new_name] = original_name + + tool_defs[new_name] = tool_def + + freezable_prepared = _FreezableIndividuallyPreparedToolset( + frozen_wrapped, self.prepare_func, list(tool_defs.values()), name_map + ) + return FrozenToolset[AgentDepsT](freezable_prepared, ctx) + + +@dataclass +class FunctionToolset(IndividuallyPreparedToolset[AgentDepsT]): + """A toolset that contains function tools.""" + + def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + wrapped = _BareFunctionToolset(tools, max_retries) + + async def prepare_tool_def(ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: + tool_name = tool_def.name + ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) + return await wrapped.tools[tool_name].prepare_tool_def(ctx) + + super().__init__(wrapped, prepare_tool_def) + + +@dataclass(init=False) +class _FreezableIndividuallyPreparedToolset(IndividuallyPreparedToolset[AgentDepsT]): + name_map: dict[str, str] + _tool_defs: list[ToolDefinition] + + def __init__( + self, + wrapped: AbstractToolset[AgentDepsT], + prepare_func: ToolPrepareFunc[AgentDepsT], + tool_defs: list[ToolDefinition], + name_map: dict[str, str], + ): + super().__init__(wrapped, prepare_func) + self._tool_defs = tool_defs + self.name_map = name_map + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super().get_tool_args_validator(ctx, self._map_name(name)) + + def max_retries_for_tool(self, name: str) -> int: + return super().max_retries_for_tool(self._map_name(name)) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await super().call_tool(ctx, self._map_name(name), tool_args, *args, **kwargs) + + def _map_name(self, name: str) -> str: + return self.name_map.get(name, name) + + +@dataclass(init=False) +class FilteredToolset(IndividuallyPreparedToolset[AgentDepsT]): + """A toolset that filters the tools it contains using a filter function.""" + + def __init__( + self, + toolset: AbstractToolset[AgentDepsT], + filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool], + ): + async def filter_tool_def(ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: + return tool_def if filter_func(ctx, tool_def) else None + + super().__init__(toolset, filter_tool_def) + + +class CallToolFunc(Protocol): + """A function protocol that represents a tool call.""" + + def __call__(self, name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any) -> Awaitable[Any]: ... + + +ToolProcessFunc = Callable[ + [ + RunContext[AgentDepsT], + CallToolFunc, + str, + dict[str, Any], + ], + Awaitable[Any], +] + + +@dataclass +class ProcessedToolset(WrapperToolset[AgentDepsT]): + """A toolset that lets the tool call arguments and return value be customized using a process function.""" + + process: ToolProcessFunc[AgentDepsT] + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_wrapped = await self.wrapped.freeze_for_run(ctx) + processed = ProcessedToolset[AgentDepsT](frozen_wrapped, self.process) + return FrozenToolset[AgentDepsT](processed, ctx) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.process(ctx, partial(self.wrapped.call_tool, ctx), name, tool_args, *args, **kwargs) + + +@dataclass(init=False) +class FrozenToolset(WrapperToolset[AgentDepsT]): + """A toolset that is frozen for a specific run.""" + + ctx: RunContext[AgentDepsT] + _tool_defs: list[ToolDefinition] + _tool_names: list[str] + _retries: dict[str, int] + + def __init__( + self, + wrapped: AbstractToolset[AgentDepsT], + ctx: RunContext[AgentDepsT], + tool_defs: list[ToolDefinition] | None = None, + ): + self.wrapped = wrapped + self.ctx = ctx + self._tool_defs = wrapped.tool_defs if tool_defs is None else tool_defs + self._tool_names = [tool_def.name for tool_def in self._tool_defs] + self._retries = ctx.retries.copy() + + @property + def name(self) -> str: + return self.wrapped.name + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + if ctx == self.ctx: + return self + else: + ctx = replace(ctx, retries=self._retries) + return await self.wrapped.freeze_for_run(ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs + + @property + def tool_names(self) -> list[str]: + return self._tool_names + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None + ) -> dict[str, Any]: + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) + try: + return super().validate_tool_args(ctx, name, args) + except ValidationError as e: + return self._on_error(name, e) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) + try: + return await super().call_tool(ctx, name, tool_args, *args, **kwargs) + except ModelRetry as e: + return self._on_error(name, e) + + def _on_error(self, name: str, e: Exception) -> Never: + max_retries = self.max_retries_for_tool(name) + current_retry = self._retries.get(name, 0) + if current_retry == max_retries: + raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {max_retries}') from e + else: + self._retries[name] = current_retry + 1 + raise e diff --git a/tests/test_examples.py b/tests/test_examples.py index d0a2706ad..30e114043 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -36,6 +36,7 @@ from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.toolset import AbstractToolset, MCPToolset from .conftest import ClientWithHandler, TestEnv, try_import @@ -261,10 +262,15 @@ async def __aenter__(self) -> MockMCPServer: async def __aexit__(self, *args: Any) -> None: pass - @staticmethod - async def list_tools() -> list[None]: + async def list_tools(self) -> list[None]: return [] + async def call_tool(self, name: str, args: dict[str, Any], metadata: dict[str, Any] | None = None) -> Any: + return None + + def as_toolset(self, max_retries: int = 1) -> AbstractToolset: + return MCPToolset(self, max_retries=max_retries) # type: ignore + text_responses: dict[str, str | ToolCallPart] = { 'How many days between 2000-01-01 and 2025-03-18?': 'There are 9,208 days between January 1, 2000, and March 18, 2025.', diff --git a/tests/test_mcp.py b/tests/test_mcp.py index dd0a0a9c8..a780551c9 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,5 +1,7 @@ """Tests for the MCP (Model Context Protocol) server implementation.""" +from __future__ import annotations + import base64 import re from datetime import timezone @@ -23,8 +25,10 @@ ToolReturnPart, UserPromptPart, ) +from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext +from pydantic_ai.toolset import CallToolFunc from pydantic_ai.usage import Usage from .conftest import IsDatetime, IsNow, IsStr, try_import @@ -34,7 +38,7 @@ from mcp.types import CreateMessageRequestParams, ImageContent, TextContent from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response - from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult + from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio, ToolResult from pydantic_ai.models.google import GoogleModel from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.google import GoogleProvider @@ -48,16 +52,29 @@ @pytest.fixture -def agent(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - return Agent(model, mcp_servers=[server]) +def mcp_server() -> MCPServerStdio: + return MCPServerStdio('python', ['-m', 'tests.mcp_server']) -async def test_stdio_server(): +@pytest.fixture +def model(openai_api_key: str) -> Model: + return OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + +@pytest.fixture +def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: + return Agent(model, mcp_servers=[mcp_server]) + + +@pytest.fixture +def run_context(model: Model) -> RunContext[None]: + return RunContext(deps=None, model=model, usage=Usage()) + + +async def test_stdio_server(run_context: RunContext[None]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - tools = await server.list_tools() + tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[0].name == 'celsius_to_fahrenheit' assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -74,34 +91,34 @@ async def test_reentrant_context_manager(): pass -async def test_stdio_server_with_tool_prefix(): +async def test_stdio_server_with_tool_prefix(run_context: RunContext[None]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: - tools = await server.list_tools() + tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs assert all(tool.name.startswith('foo_') for tool in tools) -async def test_stdio_server_with_cwd(): +async def test_stdio_server_with_cwd(run_context: RunContext[None]): test_dir = Path(__file__).parent server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: - tools = await server.list_tools() + tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) -async def test_process_tool_call() -> None: +async def test_process_tool_call(run_context: RunContext[None]) -> None: called: bool = False async def process_tool_call( - ctx: RunContext[int], + ctx: RunContext[None], call_tool: CallToolFunc, - tool_name: str, + name: str, args: dict[str, Any], ) -> ToolResult: """A process_tool_call that sets a flag and sends deps as metadata.""" nonlocal called called = True - return await call_tool(tool_name, args, {'deps': ctx.deps}) + return await call_tool(name, args, metadata={'deps': ctx.deps}) server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) async with server: @@ -134,7 +151,7 @@ def test_sse_server_with_header_and_timeout(): @pytest.mark.vcr() async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') assert result.output == snapshot('0 degrees Celsius is equal to 32 degrees Fahrenheit.') assert result.all_messages() == snapshot( @@ -211,11 +228,11 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): with pytest.raises( UserError, match=re.escape( - "MCP Server 'MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None)' defines a tool whose name conflicts with existing tool: 'get_none'. Consider using `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from FunctionToolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." ), ): await agent.run('Get me a conflict') @@ -234,7 +251,7 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): # This means that we passed the _prepare_request_parameters check and there is no conflict in the tool name with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): await agent.run('No conflict') @@ -248,11 +265,11 @@ async def test_agent_with_server_not_running(openai_api_key: str): await agent.run('What is 0 degrees Celsius in Fahrenheit?') -async def test_log_level_unset(): +async def test_log_level_unset(run_context: RunContext[None]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) assert server.log_level is None async with server: - tools = await server.list_tools() + tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[10].name == 'get_log_level' @@ -270,7 +287,7 @@ async def test_log_level_set(): @pytest.mark.vcr() async def test_tool_returning_str(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('What is the weather in Mexico City?') assert result.output == snapshot( 'The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' @@ -349,7 +366,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Get me the product name') assert result.output == snapshot('The product name is "PydanticAI".') assert result.all_messages() == snapshot( @@ -422,7 +439,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A @pytest.mark.vcr() async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Get me the image resource') assert result.output == snapshot( 'This is an image of a sliced kiwi with a vibrant green interior and black seeds.' @@ -505,7 +522,7 @@ async def test_tool_returning_audio_resource( allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str ): model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key)) - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run("What's the content of the audio resource?", model=model) assert result.output == snapshot('The audio resource contains a voice saying "Hello, my name is Marcelo."') assert result.all_messages() == snapshot( @@ -556,7 +573,7 @@ async def test_tool_returning_audio_resource( @pytest.mark.vcr() async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Get me an image') assert result.output == snapshot('Here is an image of a sliced kiwi on a white background.') assert result.all_messages() == snapshot( @@ -636,7 +653,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im @pytest.mark.vcr() async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Get me a dict, respond on one line') assert result.output == snapshot('{"foo":"bar","baz":123}') assert result.all_messages() == snapshot( @@ -703,7 +720,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_error(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Get me an error, pass False as a value, unless the tool tells you otherwise') assert result.output == snapshot( 'I called the tool with the correct parameter, and it returned: "This is not an error."' @@ -817,7 +834,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_none(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Call the none tool and say Hello') assert result.output == snapshot('Hello! How can I assist you today?') assert result.all_messages() == snapshot( @@ -884,7 +901,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Get me multiple items and summarize in one sentence') assert result.output == snapshot( 'The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.' @@ -990,19 +1007,19 @@ async def test_client_sampling(): ) -async def test_mcp_server_raises_mcp_error(allow_model_requests: None, agent: Agent) -> None: - server = agent._mcp_servers[0] # pyright: ignore[reportPrivateUsage] - +async def test_mcp_server_raises_mcp_error( + allow_model_requests: None, mcp_server: MCPServerStdio, agent: Agent +) -> None: mcp_error = McpError(error=ErrorData(code=400, message='Test MCP error conversion')) - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): with patch.object( - server._client, # pyright: ignore[reportPrivateUsage] + mcp_server._client, # pyright: ignore[reportPrivateUsage] 'send_request', new=AsyncMock(side_effect=mcp_error), ): with pytest.raises(ModelRetry, match='Test MCP error conversion'): - await server.call_tool('test_tool', {}) + await mcp_server.call_tool('test_tool', {}) def test_map_from_mcp_params_model_request(): diff --git a/tests/test_tools.py b/tests/test_tools.py index e3eeedde0..9457ed060 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,6 +12,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext, Tool, UserError +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel @@ -486,15 +487,15 @@ def plain_tool(x: int) -> int: result = agent.run_sync('foobar') assert result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0]) - assert agent._function_tools['plain_tool'].takes_ctx is False - assert agent._function_tools['plain_tool'].max_retries == 7 + assert agent._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent._function_toolset.tools['plain_tool'].max_retries == 7 agent_infer = Agent('test', tools=[plain_tool], retries=7) result = agent_infer.run_sync('foobar') assert result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0, 0]) - assert agent_infer._function_tools['plain_tool'].takes_ctx is False - assert agent_infer._function_tools['plain_tool'].max_retries == 7 + assert agent_infer._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent_infer._function_toolset.tools['plain_tool'].max_retries == 7 def ctx_tool(ctx: RunContext[int], x: int) -> int: @@ -506,13 +507,13 @@ def test_init_tool_ctx(): agent = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7) result = agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') - assert agent._function_tools['ctx_tool'].takes_ctx is True - assert agent._function_tools['ctx_tool'].max_retries == 3 + assert agent._function_toolset.tools['ctx_tool'].takes_ctx is True + assert agent._function_toolset.tools['ctx_tool'].max_retries == 3 agent_infer = Agent('test', tools=[ctx_tool], deps_type=int) result = agent_infer.run_sync('foobar', deps=6) assert result.output == snapshot('{"ctx_tool":6}') - assert agent_infer._function_tools['ctx_tool'].takes_ctx is True + assert agent_infer._function_toolset.tools['ctx_tool'].takes_ctx is True def test_repeat_tool_by_rename(): @@ -562,7 +563,7 @@ def foo(x: int, y: str) -> str: # pragma: no cover def bar(x: int, y: str) -> str: # pragma: no cover return f'{x} {y}' - with pytest.raises(UserError, match=r"Tool name conflicts with existing tool: 'bar'."): + with pytest.raises(UserError, match="Tool name conflicts with previously renamed tool: 'bar'."): agent.run_sync('') @@ -572,7 +573,10 @@ def test_tool_return_conflict(): # this is also okay Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) # this raises an error - with pytest.raises(UserError, match="Tool name conflicts with output tool name: 'ctx_tool'"): + with pytest.raises( + UserError, + match="FunctionToolset defines a tool whose name conflicts with existing tool from OutputToolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.", + ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) @@ -1044,7 +1048,7 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: with agent.override(model=FunctionModel(get_json_schema)): result = agent.run_sync('', deps=21) json_schema = json.loads(result.output) - assert agent._function_tools['foobar'].strict is None + assert agent._function_toolset.tools['foobar'].strict is None assert json_schema['strict'] is True result = agent.run_sync('', deps=1) @@ -1071,8 +1075,8 @@ def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 def test_function_tool_inconsistent_with_schema(): @@ -1118,5 +1122,37 @@ async def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 + + +def test_tool_retries(): + prepare_tools_retries: list[int] = [] + prepare_retries: list[int] = [] + call_retries: list[int] = [] + + async def prepare_tool_defs(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition] | None: + nonlocal prepare_tools_retries + retry = ctx.retries.get('infinite_retry_tool', 0) + prepare_tools_retries.append(retry) + return tool_defs + + agent = Agent(TestModel(), retries=3, prepare_tools=prepare_tool_defs) + + async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition | None: + nonlocal prepare_retries + prepare_retries.append(ctx.retry) + return tool_def + + @agent.tool(retries=5, prepare=prepare_tool_def) + def infinite_retry_tool(ctx: RunContext[None]) -> int: + nonlocal call_retries + call_retries.append(ctx.retry) + raise ModelRetry('Please try again.') + + with pytest.raises(UnexpectedModelBehavior, match='Tool exceeded max retries count of 5'): + agent.run_sync('Begin infinite retry loop!') + + assert prepare_tools_retries == [0, 1, 2, 3, 4, 5] + assert prepare_retries == [0, 1, 2, 3, 4, 5] + assert call_retries == [0, 1, 2, 3, 4, 5] diff --git a/tests/test_toolset.py b/tests/test_toolset.py new file mode 100644 index 000000000..120ce07d0 --- /dev/null +++ b/tests/test_toolset.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass, replace +from typing import TypeVar + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai._run_context import RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolset import FunctionToolset +from pydantic_ai.usage import Usage + +pytestmark = pytest.mark.anyio + +T = TypeVar('T') + + +def build_run_context(deps: T) -> RunContext[T]: + return RunContext(deps=deps, model=TestModel(), usage=Usage(), prompt=None, messages=[], run_step=0) + + +async def test_function_toolset_freeze_for_run(): + @dataclass + class PrefixDeps: + prefix: str | None = None + + context = build_run_context(PrefixDeps()) + toolset = FunctionToolset[PrefixDeps]() + + async def prepare_add_prefix(ctx: RunContext[PrefixDeps], tool_def: ToolDefinition) -> ToolDefinition | None: + if ctx.deps.prefix is None: + return tool_def + + return replace(tool_def, name=f'{ctx.deps.prefix}_{tool_def.name}') + + @toolset.tool(prepare=prepare_add_prefix) + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + assert toolset.tool_names == snapshot(['add']) + assert toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ) + ] + ) + assert await toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) == 3 + + foo_context = build_run_context(PrefixDeps(prefix='foo')) + foo_toolset = await toolset.freeze_for_run(foo_context) + assert foo_toolset.tool_names == snapshot(['foo_add']) + assert foo_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='foo_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ) + ] + ) + assert await foo_toolset.call_tool(foo_context, 'add', {'a': 1, 'b': 2}) == 3 + + @toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b + + assert foo_toolset.tool_names == snapshot(['foo_add']) + + bar_context = build_run_context(PrefixDeps(prefix='bar')) + bar_toolset = await toolset.freeze_for_run(bar_context) + assert bar_toolset.tool_names == snapshot(['bar_add', 'subtract']) + assert bar_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='bar_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='subtract', + description='Subtract two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ] + ) + assert await bar_toolset.call_tool(bar_context, 'add', {'a': 1, 'b': 2}) == 3 + + bar_foo_toolset = await foo_toolset.freeze_for_run(bar_context) + assert bar_foo_toolset == bar_toolset From 0f8da7450fb6ca2103e7075aa967ce2ecd154955 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 20 Jun 2025 23:57:34 +0000 Subject: [PATCH 47/90] Make MCPServer a Toolset --- pydantic_ai_slim/pydantic_ai/agent.py | 10 ++- pydantic_ai_slim/pydantic_ai/mcp.py | 74 ++++++++++++++++---- pydantic_ai_slim/pydantic_ai/toolset.py | 92 +++++-------------------- tests/test_examples.py | 23 +++++-- tests/test_mcp.py | 40 +++++------ 5 files changed, 116 insertions(+), 123 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 0c508a2da..21668fbd9 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -235,7 +235,9 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), + mcp_servers: Sequence[ + MCPServer + ] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets toolsets: Sequence[AbstractToolset[AgentDepsT]] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', @@ -364,11 +366,7 @@ def __init__( # This will raise errors for any name conflicts CombinedToolset[AgentDepsT]([self._output_toolset, self._function_toolset]) - mcp_toolsets = [ - cast(AbstractToolset[AgentDepsT], mcp_server.as_toolset(max_retries=self._max_result_retries)) - for mcp_server in mcp_servers - ] - toolset = CombinedToolset[AgentDepsT]([self._function_toolset, *toolsets, *mcp_toolsets]) + toolset = CombinedToolset[AgentDepsT]([self._function_toolset, *toolsets, *mcp_servers]) if prepare_tools: toolset = PreparedToolset[AgentDepsT](toolset, prepare_tools) self._toolset = toolset diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 9574a830a..b8f07fb87 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -16,8 +16,11 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from typing_extensions import Self, assert_never, deprecated +from pydantic_ai._run_context import RunContext +from pydantic_ai.tools import ToolDefinition + from .exceptions import UserError -from .toolset import AbstractToolset, MCPToolset, PrefixedToolset, ProcessedToolset, ToolProcessFunc +from .toolset import AbstractToolset, FrozenToolset, PrefixedToolset, ProcessedToolset, ToolProcessFunc try: from mcp import types as mcp_types @@ -40,7 +43,7 @@ __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP' -class MCPServer(ABC): +class MCPServer(AbstractToolset[Any], ABC): """Base class for attaching agents to MCP servers. See for more information. @@ -51,7 +54,7 @@ class MCPServer(ABC): log_level: mcp_types.LoggingLevel | None = None log_handler: LoggingFnT | None = None init_timeout: float = 5 - process_tool_call: ToolProcessFunc | None = None + process_tool_call: ToolProcessFunc[Any] | None = None # } end of "abstract fields" _running_count: int = 0 @@ -81,6 +84,14 @@ def is_running(self) -> bool: """Check if the MCP server is running.""" return bool(self._running_count) + @property + def name(self) -> str: + return repr(self) + + @property + def name_conflict_hint(self) -> str: + return 'Consider setting `tool_prefix` to avoid name conflicts.' + async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. @@ -95,16 +106,22 @@ async def list_tools(self) -> list[mcp_types.Tool]: async def call_tool( self, - tool_name: str, - arguments: dict[str, Any], + ctx: RunContext[Any], + name: str, + tool_args: dict[str, Any], + *args: Any, metadata: dict[str, Any] | None = None, + **kwargs: Any, ) -> ToolResult: """Call a tool on the server. Args: - tool_name: The name of the tool to call. - arguments: The arguments to pass to the tool. + ctx: The run context of the tool call. + name: The name of the tool to call. + tool_args: The arguments to pass to the tool. + *args: Additional arguments passed by a tool call processor. metadata: Request-level metadata (optional) + **kwargs: Additional keyword arguments passed by a tool call processor. Returns: The result of the tool call. @@ -121,8 +138,8 @@ async def call_tool( mcp_types.CallToolRequest( method='tools/call', params=mcp_types.CallToolRequestParams( - name=tool_name, - arguments=arguments, + name=name, + arguments=tool_args, _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, ), ) @@ -140,13 +157,42 @@ async def call_tool( else: return content[0] if len(content) == 1 else content - def as_toolset(self, max_retries: int = 1) -> AbstractToolset: - toolset = MCPToolset(self, max_retries=max_retries) + async def freeze_for_run(self, ctx: RunContext[Any]) -> FrozenToolset[Any]: + frozen_self = FrozenToolset[Any](self, ctx, await self.list_tool_defs()) + toolset = frozen_self if self.process_tool_call: toolset = ProcessedToolset(toolset, self.process_tool_call) if self.tool_prefix: toolset = PrefixedToolset(toolset, self.tool_prefix) - return toolset + return FrozenToolset[Any](toolset, ctx, original=self) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [] + + async def list_tool_defs(self) -> list[ToolDefinition]: + mcp_tools = await self.list_tools() + return [ + ToolDefinition( + name=mcp_tool.name, + description=mcp_tool.description or '', + parameters_json_schema=mcp_tool.inputSchema, + ) + for mcp_tool in mcp_tools + ] + + def get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_core.SchemaValidator: + return pydantic_core.SchemaValidator( + schema=pydantic_core.core_schema.dict_schema( + pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema() + ) + ) + + def max_retries_for_tool(self, name: str) -> int: + return 1 + + def set_mcp_sampling_model(self, model: models.Model) -> None: + self.sampling_model = model async def __aenter__(self) -> Self: if self._running_count == 0: @@ -323,7 +369,7 @@ async def main(): timeout: float = 5 """ The timeout in seconds to wait for the client to initialize.""" - process_tool_call: ToolProcessFunc | None = None + process_tool_call: ToolProcessFunc[Any] | None = None """Hook to customize tool calling and optionally pass extra metadata.""" @asynccontextmanager @@ -418,7 +464,7 @@ class _MCPServerHTTP(MCPServer): init_timeout: float = 5 """The timeout in seconds to wait for the client to initialize.""" - process_tool_call: ToolProcessFunc | None = None + process_tool_call: ToolProcessFunc[Any] | None = None """Hook to customize tool calling and optionally pass extra metadata.""" @property diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index bd58c53b1..1b4c33b61 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -11,7 +11,7 @@ from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema -from pydantic_core import SchemaValidator, core_schema +from pydantic_core import SchemaValidator from typing_extensions import Self from ._output import BaseOutputSchema @@ -29,7 +29,6 @@ ) if TYPE_CHECKING: - from pydantic_ai.mcp import MCPServer from pydantic_ai.models import Model @@ -98,7 +97,7 @@ def set_mcp_sampling_model(self, model: Model) -> None: @dataclass(init=False) -class _BareFunctionToolset(AbstractToolset[AgentDepsT]): +class FunctionToolset(AbstractToolset[AgentDepsT]): """A toolset that functions can be registered to as tools.""" max_retries: int = field(default=1) @@ -259,6 +258,16 @@ def register_tool(self, tool: Tool[AgentDepsT]) -> None: tool.max_retries = self.max_retries self.tools[tool.name] = tool + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_self = FrozenToolset[AgentDepsT](self, ctx) + toolset = await IndividuallyPreparedToolset(frozen_self, self._prepare_tool_def).freeze_for_run(ctx) + return FrozenToolset[AgentDepsT](toolset, ctx, original=self) + + async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: + tool_name = tool_def.name + ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) + return await self.tools[tool_name].prepare_tool_def(ctx) + @property def tool_defs(self) -> list[ToolDefinition]: return [tool.tool_def for tool in self.tools.values()] @@ -301,63 +310,6 @@ async def call_tool( return await self.output_schema.tools[name].processor.process(tool_args, ctx) -@dataclass -class MCPToolset(AbstractToolset[AgentDepsT]): - """A toolset that contains MCP tools and handles running the server.""" - - server: MCPServer - max_retries: int = field(default=1) - - @property - def name(self) -> str: - return repr(self.server) - - @property - def name_conflict_hint(self) -> str: - return 'Consider setting `tool_prefix` to avoid name conflicts.' - - async def __aenter__(self) -> Self: - await self.server.__aenter__() - return self - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: - return await self.server.__aexit__(exc_type, exc_value, traceback) - - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - return FrozenToolset[AgentDepsT](self, ctx, await self.list_tool_defs()) - - @property - def tool_defs(self) -> list[ToolDefinition]: - return [] - - async def list_tool_defs(self) -> list[ToolDefinition]: - mcp_tools = await self.server.list_tools() - return [ - ToolDefinition( - name=mcp_tool.name, - description=mcp_tool.description or '', - parameters_json_schema=mcp_tool.inputSchema, - ) - for mcp_tool in mcp_tools - ] - - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return SchemaValidator(schema=core_schema.dict_schema(core_schema.str_schema(), core_schema.any_schema())) - - def max_retries_for_tool(self, name: str) -> int: - return self.max_retries - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], metadata: dict[str, Any] | None = None - ) -> Any: - return await self.server.call_tool(name, tool_args, metadata) - - def set_mcp_sampling_model(self, model: Model) -> None: - self.server.sampling_model = model - - @dataclass class WrapperToolset(AbstractToolset[AgentDepsT]): """A toolset that wraps another toolset and delegates to it.""" @@ -569,21 +521,6 @@ async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[Age return FrozenToolset[AgentDepsT](freezable_prepared, ctx) -@dataclass -class FunctionToolset(IndividuallyPreparedToolset[AgentDepsT]): - """A toolset that contains function tools.""" - - def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): - wrapped = _BareFunctionToolset(tools, max_retries) - - async def prepare_tool_def(ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: - tool_name = tool_def.name - ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) - return await wrapped.tools[tool_name].prepare_tool_def(ctx) - - super().__init__(wrapped, prepare_tool_def) - - @dataclass(init=False) class _FreezableIndividuallyPreparedToolset(IndividuallyPreparedToolset[AgentDepsT]): name_map: dict[str, str] @@ -676,18 +613,21 @@ class FrozenToolset(WrapperToolset[AgentDepsT]): _tool_defs: list[ToolDefinition] _tool_names: list[str] _retries: dict[str, int] + _original: AbstractToolset[AgentDepsT] def __init__( self, wrapped: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT], tool_defs: list[ToolDefinition] | None = None, + original: AbstractToolset[AgentDepsT] | None = None, ): self.wrapped = wrapped self.ctx = ctx self._tool_defs = wrapped.tool_defs if tool_defs is None else tool_defs self._tool_names = [tool_def.name for tool_def in self._tool_defs] self._retries = ctx.retries.copy() + self._original = original or wrapped @property def name(self) -> str: @@ -698,7 +638,7 @@ async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[Age return self else: ctx = replace(ctx, retries=self._retries) - return await self.wrapped.freeze_for_run(ctx) + return await self._original.freeze_for_run(ctx) @property def tool_defs(self) -> list[ToolDefinition]: diff --git a/tests/test_examples.py b/tests/test_examples.py index 30e114043..9010ac807 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -16,11 +16,13 @@ import pytest from _pytest.mark import ParameterSet from devtools import debug +from pydantic_core import SchemaValidator, core_schema from pytest_examples import CodeExample, EvalExample, find_examples from pytest_mock import MockerFixture from rich.console import Console from pydantic_ai import ModelHTTPError +from pydantic_ai._run_context import RunContext from pydantic_ai._utils import group_by_temporal from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( @@ -36,7 +38,8 @@ from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.toolset import AbstractToolset, MCPToolset +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolset import AbstractToolset from .conftest import ClientWithHandler, TestEnv, try_import @@ -253,7 +256,7 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: raise ValueError(f'Unexpected prompt: {prompt}') -class MockMCPServer: +class MockMCPServer(AbstractToolset[Any]): is_running = True async def __aenter__(self) -> MockMCPServer: @@ -262,14 +265,20 @@ async def __aenter__(self) -> MockMCPServer: async def __aexit__(self, *args: Any) -> None: pass - async def list_tools(self) -> list[None]: + @property + def tool_defs(self) -> list[ToolDefinition]: return [] - async def call_tool(self, name: str, args: dict[str, Any], metadata: dict[str, Any] | None = None) -> Any: - return None + def get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator: + return SchemaValidator(core_schema.any_schema()) + + def max_retries_for_tool(self, name: str) -> int: + return 0 - def as_toolset(self, max_retries: int = 1) -> AbstractToolset: - return MCPToolset(self, max_retries=max_retries) # type: ignore + async def call_tool( + self, ctx: RunContext[Any], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return None text_responses: dict[str, str | ToolCallPart] = { diff --git a/tests/test_mcp.py b/tests/test_mcp.py index a780551c9..05c51c213 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -67,20 +67,20 @@ def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: @pytest.fixture -def run_context(model: Model) -> RunContext[None]: - return RunContext(deps=None, model=model, usage=Usage()) +def run_context(model: Model) -> RunContext[int]: + return RunContext(deps=0, model=model, usage=Usage()) -async def test_stdio_server(run_context: RunContext[None]): +async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs + tools = (await server.freeze_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[0].name == 'celsius_to_fahrenheit' assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') # Test calling the temperature conversion tool - result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 0}) + result = await server.call_tool(run_context, 'celsius_to_fahrenheit', {'celsius': 0}) assert result == snapshot('32.0') @@ -91,26 +91,26 @@ async def test_reentrant_context_manager(): pass -async def test_stdio_server_with_tool_prefix(run_context: RunContext[None]): +async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: - tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs + tools = (await server.freeze_for_run(run_context)).tool_defs assert all(tool.name.startswith('foo_') for tool in tools) -async def test_stdio_server_with_cwd(run_context: RunContext[None]): +async def test_stdio_server_with_cwd(run_context: RunContext[int]): test_dir = Path(__file__).parent server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: - tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs + tools = (await server.freeze_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) -async def test_process_tool_call(run_context: RunContext[None]) -> None: +async def test_process_tool_call(run_context: RunContext[int]) -> int: called: bool = False async def process_tool_call( - ctx: RunContext[None], + ctx: RunContext[int], call_tool: CallToolFunc, name: str, args: dict[str, Any], @@ -265,23 +265,23 @@ async def test_agent_with_server_not_running(openai_api_key: str): await agent.run('What is 0 degrees Celsius in Fahrenheit?') -async def test_log_level_unset(run_context: RunContext[None]): +async def test_log_level_unset(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) assert server.log_level is None async with server: - tools = (await server.as_toolset().freeze_for_run(run_context)).tool_defs + tools = (await server.freeze_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[10].name == 'get_log_level' - result = await server.call_tool('get_log_level', {}) + result = await server.call_tool(run_context, 'get_log_level', {}) assert result == snapshot('unset') -async def test_log_level_set(): +async def test_log_level_set(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], log_level='info') assert server.log_level == 'info' async with server: - result = await server.call_tool('get_log_level', {}) + result = await server.call_tool(run_context, 'get_log_level', {}) assert result == snapshot('info') @@ -990,12 +990,12 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ) -async def test_client_sampling(): +async def test_client_sampling(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], log_level='info') server.sampling_model = TestModel(custom_output_text='sampling model response') assert server.log_level == 'info' async with server: - result = await server.call_tool('use_sampling', {'foo': 'bar'}) + result = await server.call_tool(run_context, 'use_sampling', {'foo': 'bar'}) assert result == snapshot( { 'meta': None, @@ -1008,7 +1008,7 @@ async def test_client_sampling(): async def test_mcp_server_raises_mcp_error( - allow_model_requests: None, mcp_server: MCPServerStdio, agent: Agent + allow_model_requests: None, mcp_server: MCPServerStdio, agent: Agent, run_context: RunContext[int] ) -> None: mcp_error = McpError(error=ErrorData(code=400, message='Test MCP error conversion')) @@ -1019,7 +1019,7 @@ async def test_mcp_server_raises_mcp_error( new=AsyncMock(side_effect=mcp_error), ): with pytest.raises(ModelRetry, match='Test MCP error conversion'): - await mcp_server.call_tool('test_tool', {}) + await mcp_server.call_tool(run_context, 'test_tool', {}) def test_map_from_mcp_params_model_request(): From 3d2012ce1d39a0d436594a616691f744e697743d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Sat, 21 Jun 2025 00:22:08 +0000 Subject: [PATCH 48/90] Add MappedToolset --- pydantic_ai_slim/pydantic_ai/mcp.py | 10 +-- pydantic_ai_slim/pydantic_ai/toolset.py | 100 ++++++++++++------------ 2 files changed, 57 insertions(+), 53 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ad51a4bcb..7a014a878 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -159,13 +159,13 @@ async def call_tool( return content[0] if len(content) == 1 else content async def freeze_for_run(self, ctx: RunContext[Any]) -> FrozenToolset[Any]: - frozen_self = FrozenToolset[Any](self, ctx, await self.list_tool_defs()) - toolset = frozen_self + frozen_self = FrozenToolset(self, ctx, await self.list_tool_defs()) + frozen_toolset = frozen_self if self.process_tool_call: - toolset = ProcessedToolset(toolset, self.process_tool_call) + frozen_toolset = await ProcessedToolset(frozen_toolset, self.process_tool_call).freeze_for_run(ctx) if self.tool_prefix: - toolset = PrefixedToolset(toolset, self.tool_prefix) - return FrozenToolset[Any](toolset, ctx, original=self) + frozen_toolset = await PrefixedToolset(frozen_toolset, self.tool_prefix).freeze_for_run(ctx) + return FrozenToolset(frozen_toolset, ctx, original=self) @property def tool_defs(self) -> list[ToolDefinition]: diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 1b4c33b61..c86625867 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -58,7 +58,7 @@ async def __aexit__( return None async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - return FrozenToolset[AgentDepsT](self, ctx) + return FrozenToolset(self, ctx) @property @abstractmethod @@ -259,9 +259,9 @@ def register_tool(self, tool: Tool[AgentDepsT]) -> None: self.tools[tool.name] = tool async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_self = FrozenToolset[AgentDepsT](self, ctx) - toolset = await IndividuallyPreparedToolset(frozen_self, self._prepare_tool_def).freeze_for_run(ctx) - return FrozenToolset[AgentDepsT](toolset, ctx, original=self) + frozen_self = FrozenToolset(self, ctx) + frozen_prepared = await IndividuallyPreparedToolset(frozen_self, self._prepare_tool_def).freeze_for_run(ctx) + return FrozenToolset(frozen_prepared, ctx, original=self) async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: tool_name = tool_def.name @@ -399,8 +399,8 @@ async def __aexit__( async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: frozen_toolsets = await asyncio.gather(*[toolset.freeze_for_run(ctx) for toolset in self.toolsets]) - freezable_combined = CombinedToolset[AgentDepsT](frozen_toolsets) - return FrozenToolset[AgentDepsT](freezable_combined, ctx) + freezable_combined = CombinedToolset(frozen_toolsets) + return FrozenToolset(freezable_combined, ctx) @property def tool_defs(self) -> list[ToolDefinition]: @@ -440,8 +440,8 @@ class PrefixedToolset(WrapperToolset[AgentDepsT]): async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: frozen_wrapped = await self.wrapped.freeze_for_run(ctx) - freezable_prefixed = PrefixedToolset[AgentDepsT](frozen_wrapped, self.prefix) - return FrozenToolset[AgentDepsT](freezable_prefixed, ctx) + freezable_prefixed = PrefixedToolset(frozen_wrapped, self.prefix) + return FrozenToolset(freezable_prefixed, ctx) @property def tool_defs(self) -> list[ToolDefinition]: @@ -484,59 +484,32 @@ async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[Age if len(prepared_tool_names - original_tool_names) > 0: raise UserError('Prepare function is not allowed to change tool names or add new tools.') - freezable_prepared = PreparedToolset[AgentDepsT](frozen_wrapped, self.prepare_func) - return FrozenToolset[AgentDepsT](freezable_prepared, ctx, prepared_tool_defs) - - -@dataclass -class IndividuallyPreparedToolset(WrapperToolset[AgentDepsT]): - """A toolset that prepares the tools it contains using a per-tool prepare function.""" - - prepare_func: ToolPrepareFunc[AgentDepsT] - - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_wrapped = await self.wrapped.freeze_for_run(ctx) - - tool_defs: dict[str, ToolDefinition] = {} - name_map: dict[str, str] = {} - for original_tool_def in frozen_wrapped.tool_defs: - original_name = original_tool_def.name - tool_def = await self.prepare_func(ctx, original_tool_def) - if not tool_def: - continue - - new_name = tool_def.name - if new_name in tool_defs: - if new_name != original_name: - raise UserError(f"Renaming tool '{original_name}' to '{new_name}' conflicts with existing tool.") - else: - raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') - name_map[new_name] = original_name - - tool_defs[new_name] = tool_def - - freezable_prepared = _FreezableIndividuallyPreparedToolset( - frozen_wrapped, self.prepare_func, list(tool_defs.values()), name_map - ) - return FrozenToolset[AgentDepsT](freezable_prepared, ctx) + freezable_prepared = PreparedToolset(frozen_wrapped, self.prepare_func) + return FrozenToolset(freezable_prepared, ctx, prepared_tool_defs) @dataclass(init=False) -class _FreezableIndividuallyPreparedToolset(IndividuallyPreparedToolset[AgentDepsT]): +class MappedToolset(WrapperToolset[AgentDepsT]): + """A toolset that maps the names of the tools it contains.""" + name_map: dict[str, str] _tool_defs: list[ToolDefinition] def __init__( self, wrapped: AbstractToolset[AgentDepsT], - prepare_func: ToolPrepareFunc[AgentDepsT], tool_defs: list[ToolDefinition], name_map: dict[str, str], ): - super().__init__(wrapped, prepare_func) + super().__init__(wrapped) self._tool_defs = tool_defs self.name_map = name_map + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_wrapped = await self.wrapped.freeze_for_run(ctx) + freezable_mapped = MappedToolset(frozen_wrapped, self._tool_defs, self.name_map) + return FrozenToolset(freezable_mapped, ctx) + @property def tool_defs(self) -> list[ToolDefinition]: return self._tool_defs @@ -556,6 +529,37 @@ def _map_name(self, name: str) -> str: return self.name_map.get(name, name) +@dataclass +class IndividuallyPreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a per-tool prepare function.""" + + prepare_func: ToolPrepareFunc[AgentDepsT] + + async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + frozen_wrapped = await self.wrapped.freeze_for_run(ctx) + + tool_defs: dict[str, ToolDefinition] = {} + name_map: dict[str, str] = {} + for original_tool_def in frozen_wrapped.tool_defs: + original_name = original_tool_def.name + tool_def = await self.prepare_func(ctx, original_tool_def) + if not tool_def: + continue + + new_name = tool_def.name + if new_name in tool_defs: + if new_name != original_name: + raise UserError(f"Renaming tool '{original_name}' to '{new_name}' conflicts with existing tool.") + else: + raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') + name_map[new_name] = original_name + + tool_defs[new_name] = tool_def + + frozen_mapped = await MappedToolset(frozen_wrapped, list(tool_defs.values()), name_map).freeze_for_run(ctx) + return FrozenToolset(frozen_mapped, ctx, original=self) + + @dataclass(init=False) class FilteredToolset(IndividuallyPreparedToolset[AgentDepsT]): """A toolset that filters the tools it contains using a filter function.""" @@ -596,8 +600,8 @@ class ProcessedToolset(WrapperToolset[AgentDepsT]): async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: frozen_wrapped = await self.wrapped.freeze_for_run(ctx) - processed = ProcessedToolset[AgentDepsT](frozen_wrapped, self.process) - return FrozenToolset[AgentDepsT](processed, ctx) + processed = ProcessedToolset(frozen_wrapped, self.process) + return FrozenToolset(processed, ctx) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any From 901267d6f083d51f532b6cdc976c25dfabcc0098 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Sat, 21 Jun 2025 00:24:51 +0000 Subject: [PATCH 49/90] Import Never from typing_extensions instead of typing --- pydantic_ai_slim/pydantic_ai/toolset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index c86625867..8ef792b8a 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -7,12 +7,12 @@ from dataclasses import dataclass, field, replace from functools import partial from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Generic, Never, Protocol, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, overload from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema from pydantic_core import SchemaValidator -from typing_extensions import Self +from typing_extensions import Never, Self from ._output import BaseOutputSchema from ._run_context import AgentDepsT, RunContext From b9258d7e62c70b6eab1c9a90960a01799f03adb7 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Sat, 21 Jun 2025 00:28:40 +0000 Subject: [PATCH 50/90] from __future__ import annotations --- tests/test_tools.py | 6 ++++-- tests/test_toolset.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 9457ed060..0f456b5a9 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1131,7 +1131,9 @@ def test_tool_retries(): prepare_retries: list[int] = [] call_retries: list[int] = [] - async def prepare_tool_defs(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition] | None: + async def prepare_tool_defs( + ctx: RunContext[None], tool_defs: list[ToolDefinition] + ) -> Union[list[ToolDefinition], None]: nonlocal prepare_tools_retries retry = ctx.retries.get('infinite_retry_tool', 0) prepare_tools_retries.append(retry) @@ -1139,7 +1141,7 @@ async def prepare_tool_defs(ctx: RunContext[None], tool_defs: list[ToolDefinitio agent = Agent(TestModel(), retries=3, prepare_tools=prepare_tool_defs) - async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition | None: + async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> Union[ToolDefinition, None]: nonlocal prepare_retries prepare_retries.append(ctx.retry) return tool_def diff --git a/tests/test_toolset.py b/tests/test_toolset.py index 120ce07d0..b59d7b321 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, replace from typing import TypeVar From 27ccbd1d1eac42add55e994f4a741a002b39629b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 20 Jun 2025 19:06:24 -0600 Subject: [PATCH 51/90] Update client.md --- docs/mcp/client.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 60e6c4126..fcd2bea29 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -364,7 +364,7 @@ agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent.run_toolsets(): result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. From 3031e550ac3a7917fbef842332a5048ec77db3d9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Sat, 21 Jun 2025 01:38:23 +0000 Subject: [PATCH 52/90] Pass only RunToolset to agent graph --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 10 +-- pydantic_ai_slim/pydantic_ai/agent.py | 21 ++++-- pydantic_ai_slim/pydantic_ai/mcp.py | 12 ++-- pydantic_ai_slim/pydantic_ai/toolset.py | 72 ++++++++++---------- tests/test_examples.py | 6 +- tests/test_mcp.py | 8 +-- tests/test_tools.py | 5 +- tests/test_toolset.py | 16 +++-- 8 files changed, 85 insertions(+), 65 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index bfa2cc8b4..12b241d5c 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -16,7 +16,7 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._utils import is_async_callable, run_in_executor -from pydantic_ai.toolset import AbstractToolset, CombinedToolset +from pydantic_ai.toolset import CombinedToolset, RunToolset from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT @@ -106,12 +106,12 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]] output_schema: _output.OutputSchema[OutputDataT] - output_toolset: AbstractToolset[DepsT] + output_toolset: RunToolset[DepsT] output_validators: list[_output.OutputValidator[DepsT, OutputDataT]] history_processors: Sequence[HistoryProcessor[DepsT]] - toolset: AbstractToolset[DepsT] + toolset: RunToolset[DepsT] tracer: Tracer @@ -245,8 +245,8 @@ async def _prepare_request_parameters( ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" run_context = build_run_context(ctx) - ctx.deps.toolset = toolset = await ctx.deps.toolset.freeze_for_run(run_context) - ctx.deps.output_toolset = output_toolset = await ctx.deps.output_toolset.freeze_for_run(run_context) + ctx.deps.toolset = toolset = await ctx.deps.toolset.prepare_for_run(run_context) + ctx.deps.output_toolset = output_toolset = await ctx.deps.output_toolset.prepare_for_run(run_context) # This will raise errors for any name conflicts CombinedToolset[DepsT]([output_toolset, toolset]) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 21668fbd9..c3a5ed230 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -654,8 +654,6 @@ async def main(): output_toolset = self._output_toolset if output_schema != self._output_schema: output_toolset = OutputToolset[AgentDepsT](output_schema, max_retries=self._max_result_retries) - # This will raise errors for any name conflicts - CombinedToolset[AgentDepsT]([output_toolset, self._toolset]) # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( @@ -671,6 +669,21 @@ async def main(): run_step=0, ) + run_context = RunContext[AgentDepsT]( + deps=deps, + model=model_used, + usage=usage, + prompt=user_prompt, + messages=state.message_history, + run_step=state.run_step, + ) + + run_toolset = await self._toolset.prepare_for_run(run_context) + run_output_toolset = await output_toolset.prepare_for_run(run_context) + + # This will raise errors for any name conflicts + CombinedToolset([run_output_toolset, run_toolset]) + # We consider it a user error if a user tries to restrict the result type while having an output validator that # may change the result type from the restricted type to something else. Therefore, we consider the following # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. @@ -722,9 +735,9 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: end_strategy=self.end_strategy, output_schema=output_schema, output_validators=output_validators, - output_toolset=output_toolset, + output_toolset=run_output_toolset, history_processors=self.history_processors, - toolset=self._toolset, + toolset=run_toolset, tracer=tracer, get_instructions=get_instructions, ) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 7a014a878..33c055563 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -20,7 +20,7 @@ from pydantic_ai.tools import ToolDefinition from .exceptions import UserError -from .toolset import AbstractToolset, FrozenToolset, PrefixedToolset, ProcessedToolset, ToolProcessFunc +from .toolset import AbstractToolset, PrefixedToolset, ProcessedToolset, RunToolset, ToolProcessFunc try: from mcp import types as mcp_types @@ -158,14 +158,14 @@ async def call_tool( else: return content[0] if len(content) == 1 else content - async def freeze_for_run(self, ctx: RunContext[Any]) -> FrozenToolset[Any]: - frozen_self = FrozenToolset(self, ctx, await self.list_tool_defs()) + async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: + frozen_self = RunToolset(self, ctx, await self.list_tool_defs()) frozen_toolset = frozen_self if self.process_tool_call: - frozen_toolset = await ProcessedToolset(frozen_toolset, self.process_tool_call).freeze_for_run(ctx) + frozen_toolset = await ProcessedToolset(frozen_toolset, self.process_tool_call).prepare_for_run(ctx) if self.tool_prefix: - frozen_toolset = await PrefixedToolset(frozen_toolset, self.tool_prefix).freeze_for_run(ctx) - return FrozenToolset(frozen_toolset, ctx, original=self) + frozen_toolset = await PrefixedToolset(frozen_toolset, self.tool_prefix).prepare_for_run(ctx) + return RunToolset(frozen_toolset, ctx, original=self) @property def tool_defs(self) -> list[ToolDefinition]: diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 8ef792b8a..aa433275b 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -57,8 +57,8 @@ async def __aexit__( ) -> bool | None: return None - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - return FrozenToolset(self, ctx) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + return RunToolset(self, ctx) @property @abstractmethod @@ -258,10 +258,10 @@ def register_tool(self, tool: Tool[AgentDepsT]) -> None: tool.max_retries = self.max_retries self.tools[tool.name] = tool - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_self = FrozenToolset(self, ctx) - frozen_prepared = await IndividuallyPreparedToolset(frozen_self, self._prepare_tool_def).freeze_for_run(ctx) - return FrozenToolset(frozen_prepared, ctx, original=self) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + self_for_run = RunToolset(self, ctx) + prepared_for_run = await IndividuallyPreparedToolset(self_for_run, self._prepare_tool_def).prepare_for_run(ctx) + return RunToolset(prepared_for_run, ctx, original=self) async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: tool_name = tool_def.name @@ -333,7 +333,7 @@ async def __aexit__( ) -> bool | None: return await self.wrapped.__aexit__(exc_type, exc_value, traceback) - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: raise NotImplementedError() @property @@ -397,10 +397,10 @@ async def __aexit__( self._exit_stack = None return None - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_toolsets = await asyncio.gather(*[toolset.freeze_for_run(ctx) for toolset in self.toolsets]) - freezable_combined = CombinedToolset(frozen_toolsets) - return FrozenToolset(freezable_combined, ctx) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + toolsets_for_run = await asyncio.gather(*[toolset.prepare_for_run(ctx) for toolset in self.toolsets]) + combined_for_run = CombinedToolset(toolsets_for_run) + return RunToolset(combined_for_run, ctx) @property def tool_defs(self) -> list[ToolDefinition]: @@ -438,10 +438,10 @@ class PrefixedToolset(WrapperToolset[AgentDepsT]): prefix: str - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_wrapped = await self.wrapped.freeze_for_run(ctx) - freezable_prefixed = PrefixedToolset(frozen_wrapped, self.prefix) - return FrozenToolset(freezable_prefixed, ctx) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + prefixed_for_run = PrefixedToolset(wrapped_for_run, self.prefix) + return RunToolset(prefixed_for_run, ctx) @property def tool_defs(self) -> list[ToolDefinition]: @@ -474,9 +474,9 @@ class PreparedToolset(WrapperToolset[AgentDepsT]): prepare_func: ToolsPrepareFunc[AgentDepsT] - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_wrapped = await self.wrapped.freeze_for_run(ctx) - original_tool_defs = frozen_wrapped.tool_defs + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + original_tool_defs = wrapped_for_run.tool_defs prepared_tool_defs = await self.prepare_func(ctx, original_tool_defs) or [] original_tool_names = {tool_def.name for tool_def in original_tool_defs} @@ -484,8 +484,8 @@ async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[Age if len(prepared_tool_names - original_tool_names) > 0: raise UserError('Prepare function is not allowed to change tool names or add new tools.') - freezable_prepared = PreparedToolset(frozen_wrapped, self.prepare_func) - return FrozenToolset(freezable_prepared, ctx, prepared_tool_defs) + prepared_for_run = PreparedToolset(wrapped_for_run, self.prepare_func) + return RunToolset(prepared_for_run, ctx, prepared_tool_defs) @dataclass(init=False) @@ -505,10 +505,10 @@ def __init__( self._tool_defs = tool_defs self.name_map = name_map - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_wrapped = await self.wrapped.freeze_for_run(ctx) - freezable_mapped = MappedToolset(frozen_wrapped, self._tool_defs, self.name_map) - return FrozenToolset(freezable_mapped, ctx) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + mapped_for_run = MappedToolset(wrapped_for_run, self._tool_defs, self.name_map) + return RunToolset(mapped_for_run, ctx) @property def tool_defs(self) -> list[ToolDefinition]: @@ -535,12 +535,12 @@ class IndividuallyPreparedToolset(WrapperToolset[AgentDepsT]): prepare_func: ToolPrepareFunc[AgentDepsT] - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_wrapped = await self.wrapped.freeze_for_run(ctx) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) tool_defs: dict[str, ToolDefinition] = {} name_map: dict[str, str] = {} - for original_tool_def in frozen_wrapped.tool_defs: + for original_tool_def in wrapped_for_run.tool_defs: original_name = original_tool_def.name tool_def = await self.prepare_func(ctx, original_tool_def) if not tool_def: @@ -556,8 +556,8 @@ async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[Age tool_defs[new_name] = tool_def - frozen_mapped = await MappedToolset(frozen_wrapped, list(tool_defs.values()), name_map).freeze_for_run(ctx) - return FrozenToolset(frozen_mapped, ctx, original=self) + mapped_for_run = await MappedToolset(wrapped_for_run, list(tool_defs.values()), name_map).prepare_for_run(ctx) + return RunToolset(mapped_for_run, ctx, original=self) @dataclass(init=False) @@ -598,10 +598,10 @@ class ProcessedToolset(WrapperToolset[AgentDepsT]): process: ToolProcessFunc[AgentDepsT] - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: - frozen_wrapped = await self.wrapped.freeze_for_run(ctx) - processed = ProcessedToolset(frozen_wrapped, self.process) - return FrozenToolset(processed, ctx) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + processed = ProcessedToolset(wrapped_for_run, self.process) + return RunToolset(processed, ctx) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any @@ -610,7 +610,7 @@ async def call_tool( @dataclass(init=False) -class FrozenToolset(WrapperToolset[AgentDepsT]): +class RunToolset(WrapperToolset[AgentDepsT]): """A toolset that is frozen for a specific run.""" ctx: RunContext[AgentDepsT] @@ -637,12 +637,12 @@ def __init__( def name(self) -> str: return self.wrapped.name - async def freeze_for_run(self, ctx: RunContext[AgentDepsT]) -> FrozenToolset[AgentDepsT]: + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: if ctx == self.ctx: return self else: ctx = replace(ctx, retries=self._retries) - return await self._original.freeze_for_run(ctx) + return await self._original.prepare_for_run(ctx) @property def tool_defs(self) -> list[ToolDefinition]: diff --git a/tests/test_examples.py b/tests/test_examples.py index b28704c53..0a223bb73 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -271,15 +271,15 @@ def tool_defs(self) -> list[ToolDefinition]: return [] def get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator: - return SchemaValidator(core_schema.any_schema()) + return SchemaValidator(core_schema.any_schema()) # pragma: lax no cover def max_retries_for_tool(self, name: str) -> int: - return 0 + return 0 # pragma: lax no cover async def call_tool( self, ctx: RunContext[Any], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any ) -> Any: - return None + return None # pragma: lax no cover text_responses: dict[str, str | ToolCallPart] = { diff --git a/tests/test_mcp.py b/tests/test_mcp.py index aed242907..83c4d33b7 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -74,7 +74,7 @@ def run_context(model: Model) -> RunContext[int]: async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - tools = (await server.freeze_for_run(run_context)).tool_defs + tools = (await server.prepare_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[0].name == 'celsius_to_fahrenheit' assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') @@ -94,7 +94,7 @@ async def test_reentrant_context_manager(): async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: - tools = (await server.freeze_for_run(run_context)).tool_defs + tools = (await server.prepare_for_run(run_context)).tool_defs assert all(tool.name.startswith('foo_') for tool in tools) @@ -102,7 +102,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): test_dir = Path(__file__).parent server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: - tools = (await server.freeze_for_run(run_context)).tool_defs + tools = (await server.prepare_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) @@ -269,7 +269,7 @@ async def test_log_level_unset(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) assert server.log_level is None async with server: - tools = (await server.freeze_for_run(run_context)).tool_defs + tools = (await server.prepare_for_run(run_context)).tool_defs assert len(tools) == snapshot(13) assert tools[10].name == 'get_log_level' diff --git a/tests/test_tools.py b/tests/test_tools.py index 0f456b5a9..2cb6237e9 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1155,6 +1155,7 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int: with pytest.raises(UnexpectedModelBehavior, match='Tool exceeded max retries count of 5'): agent.run_sync('Begin infinite retry loop!') - assert prepare_tools_retries == [0, 1, 2, 3, 4, 5] - assert prepare_retries == [0, 1, 2, 3, 4, 5] + # There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in. + assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5] + assert prepare_retries == [0, 0, 1, 2, 3, 4, 5] assert call_retries == [0, 1, 2, 3, 4, 5] diff --git a/tests/test_toolset.py b/tests/test_toolset.py index b59d7b321..327c139e3 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -21,7 +21,7 @@ def build_run_context(deps: T) -> RunContext[T]: return RunContext(deps=deps, model=TestModel(), usage=Usage(), prompt=None, messages=[], run_step=0) -async def test_function_toolset_freeze_for_run(): +async def test_function_toolset_prepare_for_run(): @dataclass class PrefixDeps: prefix: str | None = None @@ -57,8 +57,14 @@ def add(a: int, b: int) -> int: ) assert await toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) == 3 + no_prefix_context = build_run_context(PrefixDeps()) + no_prefix_toolset = await toolset.prepare_for_run(no_prefix_context) + assert no_prefix_toolset.tool_names == toolset.tool_names + assert no_prefix_toolset.tool_defs == toolset.tool_defs + assert await no_prefix_toolset.call_tool(no_prefix_context, 'add', {'a': 1, 'b': 2}) == 3 + foo_context = build_run_context(PrefixDeps(prefix='foo')) - foo_toolset = await toolset.freeze_for_run(foo_context) + foo_toolset = await toolset.prepare_for_run(foo_context) assert foo_toolset.tool_names == snapshot(['foo_add']) assert foo_toolset.tool_defs == snapshot( [ @@ -79,12 +85,12 @@ def add(a: int, b: int) -> int: @toolset.tool def subtract(a: int, b: int) -> int: """Subtract two numbers""" - return a - b + return a - b # pragma: lax no cover assert foo_toolset.tool_names == snapshot(['foo_add']) bar_context = build_run_context(PrefixDeps(prefix='bar')) - bar_toolset = await toolset.freeze_for_run(bar_context) + bar_toolset = await toolset.prepare_for_run(bar_context) assert bar_toolset.tool_names == snapshot(['bar_add', 'subtract']) assert bar_toolset.tool_defs == snapshot( [ @@ -112,5 +118,5 @@ def subtract(a: int, b: int) -> int: ) assert await bar_toolset.call_tool(bar_context, 'add', {'a': 1, 'b': 2}) == 3 - bar_foo_toolset = await foo_toolset.freeze_for_run(bar_context) + bar_foo_toolset = await foo_toolset.prepare_for_run(bar_context) assert bar_foo_toolset == bar_toolset From ebd0b57838cbfee9692b233b602e42d41400aea9 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Sat, 21 Jun 2025 01:42:04 +0000 Subject: [PATCH 53/90] Make WrapperToolset abstract --- pydantic_ai_slim/pydantic_ai/toolset.py | 93 ++++++++++++------------- 1 file changed, 45 insertions(+), 48 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index aa433275b..5e1ae58a0 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -310,54 +310,6 @@ async def call_tool( return await self.output_schema.tools[name].processor.process(tool_args, ctx) -@dataclass -class WrapperToolset(AbstractToolset[AgentDepsT]): - """A toolset that wraps another toolset and delegates to it.""" - - wrapped: AbstractToolset[AgentDepsT] - - @property - def name(self) -> str: - return self.wrapped.name - - @property - def name_conflict_hint(self) -> str: - return self.wrapped.name_conflict_hint - - async def __aenter__(self) -> Self: - await self.wrapped.__aenter__() - return self - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: - return await self.wrapped.__aexit__(exc_type, exc_value, traceback) - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - raise NotImplementedError() - - @property - def tool_defs(self) -> list[ToolDefinition]: - return self.wrapped.tool_defs - - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self.wrapped.get_tool_args_validator(ctx, name) - - def max_retries_for_tool(self, name: str) -> int: - return self.wrapped.max_retries_for_tool(name) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) - - def set_mcp_sampling_model(self, model: Model) -> None: - self.wrapped.set_mcp_sampling_model(model) - - def __getattr__(self, item: str): - return getattr(self.wrapped, item) # pragma: no cover - - @dataclass(init=False) class CombinedToolset(AbstractToolset[AgentDepsT]): """A toolset that combines multiple toolsets.""" @@ -432,6 +384,51 @@ def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: raise ValueError(f'Tool {name!r} not found in any toolset') from e +@dataclass +class WrapperToolset(AbstractToolset[AgentDepsT], ABC): + """A toolset that wraps another toolset and delegates to it.""" + + wrapped: AbstractToolset[AgentDepsT] + + @property + def name(self) -> str: + return self.wrapped.name + + @property + def name_conflict_hint(self) -> str: + return self.wrapped.name_conflict_hint + + async def __aenter__(self) -> Self: + await self.wrapped.__aenter__() + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return await self.wrapped.__aexit__(exc_type, exc_value, traceback) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self.wrapped.tool_defs + + def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.wrapped.get_tool_args_validator(ctx, name) + + def max_retries_for_tool(self, name: str) -> int: + return self.wrapped.max_retries_for_tool(name) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) + + def set_mcp_sampling_model(self, model: Model) -> None: + self.wrapped.set_mcp_sampling_model(model) + + def __getattr__(self, item: str): + return getattr(self.wrapped, item) # pragma: no cover + + @dataclass class PrefixedToolset(WrapperToolset[AgentDepsT]): """A toolset that prefixes the names of the tools it contains.""" From 867bf68ed1acd60e89bfdf2fad6dd9477cdaedc6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Sat, 21 Jun 2025 06:07:35 +0000 Subject: [PATCH 54/90] Introduce ToolDefinition.kind == 'pending' --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 257 +++++++++---------- pydantic_ai_slim/pydantic_ai/_output.py | 141 ++++++---- pydantic_ai_slim/pydantic_ai/agent.py | 20 +- pydantic_ai_slim/pydantic_ai/output.py | 26 +- pydantic_ai_slim/pydantic_ai/tools.py | 8 +- pydantic_ai_slim/pydantic_ai/toolset.py | 45 +++- tests/test_agent.py | 46 +++- tests/test_tools.py | 52 +++- tests/test_toolset.py | 4 +- 9 files changed, 375 insertions(+), 224 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 12b241d5c..5d38df0bd 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -4,6 +4,7 @@ import dataclasses import hashlib import json +from collections import defaultdict from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar @@ -16,14 +17,14 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._utils import is_async_callable, run_in_executor -from pydantic_ai.toolset import CombinedToolset, RunToolset +from pydantic_ai.toolset import AbstractToolset, CombinedToolset, RunToolset from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage -from .output import OutputDataT, OutputSpec +from .output import OutputDataT, OutputSpec, PendingToolCalls from .settings import ModelSettings, merge_model_settings -from .tools import RunContext +from .tools import RunContext, ToolDefinition, ToolKind if TYPE_CHECKING: pass @@ -397,7 +398,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]): _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) - _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] @@ -474,49 +474,111 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: async for event in self._events_iterator: yield event - async def _handle_tool_calls( + async def _handle_tool_calls( # noqa: C901 self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: - output_schema = ctx.deps.output_schema run_context = build_run_context(ctx) final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] - # TODO: Can we make output tools a toolset? How does CallToolsNode know the result is final, and not be sent back? + toolset = await CombinedToolset([ctx.deps.toolset, ctx.deps.output_toolset]).prepare_for_run(run_context) + + unknown_calls: list[_messages.ToolCallPart] = [] + tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list) + # TODO: Make Toolset.tool_defs a dict + tool_defs_by_name: dict[str, ToolDefinition] = {tool_def.name: tool_def for tool_def in toolset.tool_defs} + for call in tool_calls: + try: + tool_def = tool_defs_by_name[call.tool_name] + tool_calls_by_kind[tool_def.kind].append(call) + except KeyError: + unknown_calls.append(call) + # first, look for the output tool call - if isinstance(output_schema, _output.ToolOutputSchema): - for call, output_tool in output_schema.find_tool(tool_calls): + for call in tool_calls_by_kind['output']: + if final_result: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Output tool not used - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + parts.append(part) + else: try: - result_data = await output_tool.process(call, run_context) - result_data = await _validate_output(result_data, ctx, call) + result_data = await _call_tool(toolset, call, run_context) except _output.ToolRetryError as e: - # TODO: Should only increment retry stuff once per node execution, not for each tool call - # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries, e) parts.append(e.tool_retry) else: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Final result processed.', + tool_call_id=call.tool_call_id, + ) + parts.append(part) final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - break # Then build the other request parts based on end strategy - tool_responses: list[_messages.ModelRequestPart] = self._tool_responses - async for event in process_function_tools( - tool_calls, - final_result and final_result.tool_name, - final_result and final_result.tool_call_id, - ctx, - tool_responses, - ): - yield event + if final_result and ctx.deps.end_strategy == 'early': + for call in tool_calls_by_kind['function']: + parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + async for event in process_function_tools( + toolset, + tool_calls_by_kind['function'], + ctx, + parts, + ): + yield event + + if unknown_calls: + ctx.state.increment_retries(ctx.deps.max_result_retries) + async for event in process_function_tools( + toolset, + unknown_calls, + ctx, + parts, + ): + yield event + + pending_calls: list[_messages.ToolCallPart] = [] + for call in tool_calls_by_kind['pending']: + if final_result: + parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + yield _messages.FunctionToolCallEvent(call) + pending_calls.append(call) + + if pending_calls: + if not ctx.deps.output_schema.pending_tool_calls: + raise exceptions.UserError( + 'There are pending tool calls but PendingToolCalls is not among output types.' + ) + + pending_tool_names = [call.tool_name for call in pending_calls] + pending_tool_defs = { + tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in pending_tool_names + } + output_data = cast(NodeRunEndT, PendingToolCalls(pending_calls, pending_tool_defs)) + final_result = result.FinalResult(output_data) if final_result: - self._next_node = self._handle_final_result(ctx, final_result, tool_responses) + self._next_node = self._handle_final_result(ctx, final_result, parts) else: - if tool_responses: - parts.extend(tool_responses) instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( _messages.ModelRequest(parts=parts, instructions=instructions) @@ -581,10 +643,9 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str: return hashlib.sha1(identifier).hexdigest()[:6] -async def process_function_tools( # noqa C901 +async def process_function_tools( + toolset: AbstractToolset[DepsT], tool_calls: list[_messages.ToolCallPart], - output_tool_name: str | None, - output_tool_call_id: str | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], output_parts: list[_messages.ModelRequestPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: @@ -594,68 +655,15 @@ async def process_function_tools( # noqa C901 Because async iterators can't have return values, we use `output_parts` as an output argument. """ - stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early' - output_schema = ctx.deps.output_schema - - # we rely on the fact that if we found a result, it's the first output tool in the last - found_used_output_tool = False + run_context = build_run_context(ctx) calls_to_run: list[_messages.ToolCallPart] = [] call_index_to_event_id: dict[int, str] = {} for call in tool_calls: - if ( - call.tool_name == output_tool_name - and call.tool_call_id == output_tool_call_id - and not found_used_output_tool - ): - found_used_output_tool = True - output_parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Final result processed.', - tool_call_id=call.tool_call_id, - ) - ) - elif call.tool_name in output_schema.tools: # TODO: Check on toolset? - # if tool_name is in output_schema, it means we found a output tool but an error occurred in - # validation, we don't add another part here - if output_tool_name is not None: - yield _messages.FunctionToolCallEvent(call) - if found_used_output_tool: - content = 'Output tool not used - a final result was already processed.' - else: - # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part - content = 'Output tool not used - result failed validation.' - part = _messages.ToolReturnPart( - tool_name=call.tool_name, - content=content, - tool_call_id=call.tool_call_id, - ) - yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id) - output_parts.append(part) - elif call.tool_name in ctx.deps.toolset.tool_names: - if stub_function_tools: - output_parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - event = _messages.FunctionToolCallEvent(call) - yield event - call_index_to_event_id[len(calls_to_run)] = event.call_id - calls_to_run.append(call) - else: - yield _messages.FunctionToolCallEvent(call) - - part = await _unknown_tool(call.tool_name, call.tool_call_id, ctx) - yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id) - output_parts.append(part) - - if not calls_to_run: - return + event = _messages.FunctionToolCallEvent(call) + yield event + call_index_to_event_id[len(calls_to_run)] = event.call_id + calls_to_run.append(call) user_parts: list[_messages.UserPromptPart] = [] @@ -669,7 +677,7 @@ async def process_function_tools( # noqa C901 }, ): tasks = [ - asyncio.create_task(_execute_tool_call(call, ctx, ctx.deps.tracer), name=call.tool_name) + asyncio.create_task(_call_function_tool(toolset, call, run_context, ctx.deps.tracer), name=call.tool_name) for call in calls_to_run ] @@ -721,9 +729,10 @@ def process_content(content: Any) -> Any: output_parts.extend(user_parts) -async def _execute_tool_call( +async def _call_function_tool( + toolset: AbstractToolset[DepsT], tool_call: _messages.ToolCallPart, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + run_context: RunContext[DepsT], tracer: Tracer, ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: """Run the tool function asynchronously. @@ -749,57 +758,43 @@ async def _execute_tool_call( ), } - run_context = build_run_context(ctx) - toolset = ctx.deps.toolset with tracer.start_as_current_span('running tool', attributes=span_attributes): - run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id) - try: - args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args) - except ValidationError as e: - return _messages.RetryPromptPart( + response_content = await _call_tool(toolset, tool_call, run_context) + except _output.ToolRetryError as e: + return e.tool_retry + else: + return _messages.ToolReturnPart( + tool_name=tool_call.tool_name, + content=response_content, + tool_call_id=tool_call.tool_call_id, + ) + + +async def _call_tool( + toolset: AbstractToolset[DepsT], tool_call: _messages.ToolCallPart, run_context: RunContext[DepsT] +) -> Any: + run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id) + + try: + args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args) + response_content = await toolset.call_tool(run_context, tool_call.tool_name, args_dict) + except (ValidationError, exceptions.ModelRetry) as e: + if isinstance(e, ValidationError): + m = _messages.RetryPromptPart( tool_name=tool_call.tool_name, content=e.errors(include_url=False, include_context=False), tool_call_id=tool_call.tool_call_id, ) - try: - response_content = await toolset.call_tool(run_context, tool_call.tool_name, args_dict) - except exceptions.ModelRetry as e: - return _messages.RetryPromptPart( + else: + m = _messages.RetryPromptPart( tool_name=tool_call.tool_name, content=e.message, tool_call_id=tool_call.tool_call_id, ) + raise _output.ToolRetryError(m) - return _messages.ToolReturnPart( - tool_name=tool_call.tool_name, - content=response_content, - tool_call_id=tool_call.tool_call_id, - ) - - -async def _unknown_tool( - tool_name: str, - tool_call_id: str, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> _messages.RetryPromptPart: - ctx.state.increment_retries(ctx.deps.max_result_retries) - - tool_names = [ - *ctx.deps.toolset.tool_names, - *ctx.deps.output_toolset.tool_names, - ] - - if tool_names: - msg = f'Available tools: {", ".join(tool_names)}' - else: - msg = 'No tools available.' - - return _messages.RetryPromptPart( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=f'Unknown tool name: {tool_name!r}. {msg}', - ) + return response_content async def _validate_output( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 230fb4d38..163a14142 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -20,6 +20,7 @@ OutputMode, OutputSpec, OutputTypeOrFunction, + PendingToolCalls, PromptedStructuredOutput, StructuredOutputMode, TextOutput, @@ -83,6 +84,7 @@ async def validate( result: T, tool_call: _messages.ToolCallPart | None, run_context: RunContext[AgentDepsT], + wrap_validation_errors: bool = True, ) -> T: """Validate a result but calling the function. @@ -90,6 +92,7 @@ async def validate( result: The result data after Pydantic validation the message content. tool_call: The original tool call message, `None` if there was no tool call. run_context: The current run context. + wrap_validation_errors: If true, wrap the validation errors in a retry message. Returns: Result of either the validated result data (ok) or a retry message (Err). @@ -112,16 +115,22 @@ async def validate( function = cast(Callable[[Any], T], self.function) result_data = await _utils.run_in_executor(function, *args) except ModelRetry as r: - m = _messages.RetryPromptPart(content=r.message) - if tool_call is not None: - m.tool_name = tool_call.tool_name - m.tool_call_id = tool_call.tool_call_id - raise ToolRetryError(m) from r + if wrap_validation_errors: + m = _messages.RetryPromptPart(content=r.message) + if tool_call is not None: + m.tool_name = tool_call.tool_name + m.tool_call_id = tool_call.tool_call_id + raise ToolRetryError(m) from r + else: + raise r else: return result_data +@dataclass class BaseOutputSchema(ABC, Generic[OutputDataT]): + pending_tool_calls: bool + @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: raise NotImplementedError() @@ -161,7 +170,7 @@ def build( ) -> BaseOutputSchema[OutputDataT]: ... @classmethod - def build( + def build( # noqa: C901 cls, output_spec: OutputSpec[OutputDataT], *, @@ -171,37 +180,51 @@ def build( strict: bool | None = None, ) -> BaseOutputSchema[OutputDataT]: """Build an OutputSchema dataclass from an output type.""" - if output_spec is str: - return PlainTextOutputSchema() + raw_outputs = _flatten_output_spec(output_spec) + + outputs = [output for output in raw_outputs if output is not PendingToolCalls] + pending_tool_calls = len(outputs) < len(raw_outputs) + if output := next((output for output in outputs if isinstance(output, ModelStructuredOutput)), None): + if len(outputs) > 1: + raise UserError('ModelStructuredOutput cannot be mixed with other output types.') - if isinstance(output_spec, ModelStructuredOutput): return ModelStructuredOutputSchema( - cls._build_processor( - output_spec.outputs, - name=output_spec.name, - description=output_spec.description, - ) + processor=cls._build_processor( + output.outputs, + name=output.name, + description=output.description, + ), + pending_tool_calls=pending_tool_calls, ) - elif isinstance(output_spec, PromptedStructuredOutput): + elif output := next((output for output in outputs if isinstance(output, PromptedStructuredOutput)), None): + if len(outputs) > 1: + raise UserError('PromptedStructuredOutput cannot be mixed with other output types.') + return PromptedStructuredOutputSchema( - cls._build_processor( - output_spec.outputs, - name=output_spec.name, - description=output_spec.description, + processor=cls._build_processor( + output.outputs, + name=output.name, + description=output.description, ), - template=output_spec.template, + template=output.template, + pending_tool_calls=pending_tool_calls, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] - for output in _flatten_output_spec(output_spec): + for output in outputs: if output is str: text_outputs.append(cast(type[str], output)) elif isinstance(output, TextOutput): text_outputs.append(output) elif isinstance(output, ToolOutput): tool_outputs.append(output) + elif isinstance(output, (ModelStructuredOutput, PromptedStructuredOutput)): + # We can never get here because these are checked for above. + raise UserError( + 'ModelStructuredOutput and PromptedStructuredOutput must be the only output types.' + ) # pragma: no cover else: other_outputs.append(output) @@ -217,17 +240,20 @@ def build( text_output_schema = PlainTextOutputProcessor(text_output.output_function) if len(tools) == 0: - return PlainTextOutputSchema(text_output_schema) + return PlainTextOutputSchema(processor=text_output_schema, pending_tool_calls=pending_tool_calls) else: - return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools) + return ToolOrTextOutputSchema( + processor=text_output_schema, tools=tools, pending_tool_calls=pending_tool_calls + ) if len(tool_outputs) > 0: - return ToolOutputSchema(tools) + return ToolOutputSchema(tools=tools, pending_tool_calls=pending_tool_calls) if len(other_outputs) > 0: schema = OutputSchemaWithoutMode( processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), tools=tools, + pending_tool_calls=pending_tool_calls, ) if default_mode: schema = schema.with_default_mode(default_mode) @@ -317,17 +343,19 @@ def __init__( self, processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], tools: dict[str, OutputTool[OutputDataT]], + pending_tool_calls: bool, ): + super().__init__(pending_tool_calls) self.processor = processor self._tools = tools def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'model_structured': - return ModelStructuredOutputSchema(self.processor) + return ModelStructuredOutputSchema(processor=self.processor, pending_tool_calls=self.pending_tool_calls) elif mode == 'prompted_structured': - return PromptedStructuredOutputSchema(self.processor) + return PromptedStructuredOutputSchema(processor=self.processor, pending_tool_calls=self.pending_tool_calls) elif mode == 'tool': - return ToolOutputSchema(self.tools) + return ToolOutputSchema(tools=self.tools, pending_tool_calls=self.pending_tool_calls) else: assert_never(mode) @@ -489,7 +517,8 @@ async def process( class ToolOutputSchema(OutputSchema[OutputDataT]): _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) - def __init__(self, tools: dict[str, OutputTool[OutputDataT]]): + def __init__(self, tools: dict[str, OutputTool[OutputDataT]], pending_tool_calls: bool): + super().__init__(pending_tool_calls) self._tools = tools @property @@ -532,9 +561,10 @@ def __init__( self, processor: PlainTextOutputProcessor[OutputDataT] | None, tools: dict[str, OutputTool[OutputDataT]], + pending_tool_calls: bool, ): + super().__init__(tools=tools, pending_tool_calls=pending_tool_calls) self.processor = processor - self._tools = tools @property def mode(self) -> OutputMode: @@ -567,7 +597,7 @@ async def process( class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]): object_def: OutputObjectDefinition outer_typed_dict_key: str | None = None - _validator: SchemaValidator + validator: SchemaValidator _function_schema: _function_schema.FunctionSchema | None = None def __init__( @@ -580,7 +610,7 @@ def __init__( ): if inspect.isfunction(output) or inspect.ismethod(output): self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema) - self._validator = self._function_schema.validator + self.validator = self._function_schema.validator json_schema = self._function_schema.json_schema json_schema['description'] = self._function_schema.description else: @@ -596,7 +626,7 @@ def __init__( type_adapter = TypeAdapter(response_data_typed_dict) # Really a PluggableSchemaValidator, but it's API-compatible - self._validator = cast(SchemaValidator, type_adapter.validator) + self.validator = cast(SchemaValidator, type_adapter.validator) json_schema = _utils.check_object_json_schema( type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) @@ -637,11 +667,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(data, str): - output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) - else: - output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + output = self.validate(data, allow_partial) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -651,20 +677,40 @@ async def process( else: raise # pragma: lax no cover + try: + output = await self.call(output, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: lax no cover + + return output + + def validate( + self, + data: str | dict[str, Any] | None, + allow_partial: bool = False, + ) -> dict[str, Any]: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) + else: + return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + + async def call( + self, + output: Any, + run_context: RunContext[AgentDepsT], + ): if k := self.outer_typed_dict_key: output = output[k] if self._function_schema: - try: - output = await self._function_schema.call(output, run_context) - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - content=r.message, - ) - raise ToolRetryError(m) from r - else: - raise # pragma: lax no cover + output = await self._function_schema.call(output, run_context) return output @@ -860,6 +906,7 @@ def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], parameters_json_schema=object_def.json_schema, strict=object_def.strict, outer_typed_dict_key=processor.outer_typed_dict_key, + kind='output', ) async def process( diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index c3a5ed230..c38d5c448 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -651,9 +651,16 @@ async def main(): output_type_ = output_type or self.output_type + # We consider it a user error if a user tries to restrict the result type while having an output validator that + # may change the result type from the restricted type to something else. Therefore, we consider the following + # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. + output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + output_toolset = self._output_toolset - if output_schema != self._output_schema: - output_toolset = OutputToolset[AgentDepsT](output_schema, max_retries=self._max_result_retries) + if output_schema != self._output_schema or output_validators: + output_toolset = OutputToolset[AgentDepsT]( + output_schema, max_retries=self._max_result_retries, output_validators=output_validators + ) # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( @@ -684,11 +691,6 @@ async def main(): # This will raise errors for any name conflicts CombinedToolset([run_output_toolset, run_toolset]) - # We consider it a user error if a user tries to restrict the result type while having an output validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. - output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) - model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() @@ -1069,10 +1071,10 @@ async def on_complete() -> None: ] parts: list[_messages.ModelRequestPart] = [] + # TODO: Make this work again. We may have pulled too much out of process_function_tools :) async for _event in _agent_graph.process_function_tools( + graph_ctx.deps.toolset, tool_calls, - final_result_details.tool_name, - final_result_details.tool_call_id, graph_ctx, parts, ): diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 45d786dd0..9a84f2f9b 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -8,8 +8,8 @@ from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin -from .messages import RetryPromptPart -from .tools import RunContext +from .messages import RetryPromptPart, ToolCallPart +from .tools import RunContext, ToolDefinition OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) """Covariant type variable for the result data type of a run.""" @@ -232,6 +232,14 @@ def split_into_words(text: str) -> list[str]: output_function: TextOutputFunction[OutputDataT] +@dataclass +class PendingToolCalls: + """Output type for calls to tools defined as pending.""" + + tool_calls: list[ToolCallPart] + tool_defs: dict[str, ToolDefinition] + + def _get_union_args(tp: Any) -> tuple[Any, ...]: """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple.""" if typing_objects.is_typealiastype(tp): @@ -266,15 +274,23 @@ def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) ) -OutputSpec = TypeAliasType( - 'OutputSpec', +OutputSpecItem = TypeAliasType( + 'OutputSpecItem', Union[ OutputTypeOrFunction[T_co], ToolOutput[T_co], ModelStructuredOutput[T_co], PromptedStructuredOutput[T_co], TextOutput[T_co], - Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], + ], + type_params=(T_co,), +) + +OutputSpec = TypeAliasType( + 'OutputSpec', + Union[ + OutputSpecItem[T_co], + Sequence[OutputSpecItem[T_co]], ], type_params=(T_co,), ) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 200b17506..2575b025f 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from collections.abc import Awaitable, Sequence -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, Union from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue @@ -324,6 +324,9 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any` """ +ToolKind: TypeAlias = Literal['function', 'output', 'pending'] +"""Kind of tool.""" + @dataclass(repr=False) class ToolDefinition: @@ -359,4 +362,7 @@ class ToolDefinition: Note: this is currently only supported by OpenAI models. """ + kind: ToolKind = field(default='function') + """The kind of tool.""" + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 5e1ae58a0..01a9400c3 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -7,14 +7,14 @@ from dataclasses import dataclass, field, replace from functools import partial from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, overload from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema from pydantic_core import SchemaValidator from typing_extensions import Never, Self -from ._output import BaseOutputSchema +from ._output import BaseOutputSchema, OutputValidator from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError from .tools import ( @@ -74,13 +74,14 @@ def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> Sch raise NotImplementedError() def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False ) -> dict[str, Any]: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' validator = self.get_tool_args_validator(ctx, name) if isinstance(args, str): - return validator.validate_json(args or '{}') + return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial) else: - return validator.validate_python(args or {}) + return validator.validate_python(args or {}, allow_partial=pyd_allow_partial) @abstractmethod def max_retries_for_tool(self, name: str) -> int: @@ -290,15 +291,15 @@ class OutputToolset(AbstractToolset[AgentDepsT]): """A toolset that contains output tools.""" output_schema: BaseOutputSchema[Any] - max_retries: int = field(default=1) + max_retries: int = field(default=1) # TODO: Test this works + output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list) @property def tool_defs(self) -> list[ToolDefinition]: return [tool.tool_def for tool in self.output_schema.tools.values()] def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - # TODO: Should never be called for an output tool? - return self.output_schema.tools[name].processor._validator # pyright: ignore[reportPrivateUsage] + return self.output_schema.tools[name].processor.validator def max_retries_for_tool(self, name: str) -> int: return self.max_retries @@ -306,8 +307,10 @@ def max_retries_for_tool(self, name: str) -> int: async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any ) -> Any: - # TODO: Should never be called for an output tool? - return await self.output_schema.tools[name].processor.process(tool_args, ctx) + output = await self.output_schema.tools[name].processor.call(tool_args, ctx) + for validator in self.output_validators: + output = await validator.validate(output, None, ctx, wrap_validation_errors=False) + return output @dataclass(init=False) @@ -650,19 +653,23 @@ def tool_names(self) -> list[str]: return self._tool_names def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False ) -> dict[str, Any]: - ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) try: - return super().validate_tool_args(ctx, name, args) + self._validate_tool_name(name) + + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) + return super().validate_tool_args(ctx, name, args, allow_partial) except ValidationError as e: return self._on_error(name, e) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any ) -> Any: - ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) try: + self._validate_tool_name(name) + + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) return await super().call_tool(ctx, name, tool_args, *args, **kwargs) except ModelRetry as e: return self._on_error(name, e) @@ -675,3 +682,13 @@ def _on_error(self, name: str, e: Exception) -> Never: else: self._retries[name] = current_retry + 1 raise e + + def _validate_tool_name(self, name: str) -> None: + if name in self.tool_names: + return + + if self.tool_names: + msg = f'Available tools: {", ".join(self.tool_names)}' + else: + msg = 'No tools available.' + raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') diff --git a/tests/test_agent.py b/tests/test_agent.py index 7dc4e8b0a..49203c198 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -393,6 +393,7 @@ def test_response_tuple(): 'type': 'object', }, outer_typed_dict_key='response', + kind='output', ) ] ) @@ -466,6 +467,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ) ] ) @@ -545,6 +547,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Bar', @@ -555,6 +558,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Bar', 'type': 'object', }, + kind='output', ), ] ) @@ -586,6 +590,7 @@ class MyOutput(BaseModel): 'title': 'MyOutput', 'type': 'object', }, + kind='output', ) ] ) @@ -632,6 +637,7 @@ class Bar(BaseModel): }, outer_typed_dict_key='response', strict=False, + kind='output', ) ] ) @@ -670,6 +676,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -709,6 +716,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -749,6 +757,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -790,6 +799,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -986,6 +996,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1024,6 +1035,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1062,6 +1074,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Weather', @@ -1072,6 +1085,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -1248,6 +1262,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='return_weather', @@ -1258,6 +1273,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -2269,12 +2285,6 @@ def another_tool(y: int) -> int: tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), - RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, @@ -2284,6 +2294,12 @@ def another_tool(y: int) -> int: ToolReturnPart( tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ), ] ), ] @@ -2347,16 +2363,16 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='another_tool', @@ -2413,11 +2429,13 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: # Verify we got appropriate tool returns assert result.new_messages()[-1].parts == snapshot( [ - ToolReturnPart( + RetryPromptPart( + content=[ + {'type': 'missing', 'loc': ('value',), 'msg': 'Field required', 'input': {'bad_value': 'first'}} + ], tool_name='final_result', tool_call_id='first', - content='Output tool not used - result failed validation.', - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='final_result', diff --git a/tests/test_tools.py b/tests/test_tools.py index 2cb6237e9..6bc0eecfc 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -16,9 +16,11 @@ from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import ToolOutput +from pydantic_ai.output import PendingToolCalls, ToolOutput from pydantic_ai.tools import ToolDefinition +from .conftest import IsStr + def test_tool_no_ctx(): agent = Agent(TestModel()) @@ -106,6 +108,7 @@ def test_docstring_google(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) keys = list(json_schema.keys()) @@ -142,6 +145,7 @@ def test_docstring_sphinx(docstring_format: Literal['sphinx', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -181,6 +185,7 @@ def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -220,6 +225,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -257,6 +263,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -300,6 +307,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -331,6 +339,7 @@ def test_only_returns_type(): 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -353,6 +362,7 @@ def test_docstring_unknown(): 'parameters_json_schema': {'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -393,6 +403,7 @@ def test_docstring_google_no_body(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -426,6 +437,7 @@ def takes_just_model(model: Foo) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -468,6 +480,7 @@ def takes_just_model(model: Foo, z: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -807,6 +820,7 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture): 'outer_typed_dict_key': None, 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'strict': None, + 'kind': 'function', } ) @@ -876,6 +890,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': '', @@ -888,6 +903,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -972,6 +988,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': '', @@ -982,6 +999,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -1017,6 +1035,7 @@ def get_score(data: Data) -> int: ... # pragma: no branch }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -1159,3 +1178,34 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int: assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5] assert prepare_retries == [0, 0, 1, 2, 3, 4, 5] assert call_retries == [0, 1, 2, 3, 4, 5] + + +def test_pending_tool(): + agent = Agent(TestModel(), output_type=[str, PendingToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='pending') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 + + result = agent.run_sync('Hello') + assert result.output == snapshot( + PendingToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='pending', + ) + }, + ) + ) diff --git a/tests/test_toolset.py b/tests/test_toolset.py index 327c139e3..c58dce68e 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -80,7 +80,7 @@ def add(a: int, b: int) -> int: ) ] ) - assert await foo_toolset.call_tool(foo_context, 'add', {'a': 1, 'b': 2}) == 3 + assert await foo_toolset.call_tool(foo_context, 'foo_add', {'a': 1, 'b': 2}) == 3 @toolset.tool def subtract(a: int, b: int) -> int: @@ -116,7 +116,7 @@ def subtract(a: int, b: int) -> int: ), ] ) - assert await bar_toolset.call_tool(bar_context, 'add', {'a': 1, 'b': 2}) == 3 + assert await bar_toolset.call_tool(bar_context, 'bar_add', {'a': 1, 'b': 2}) == 3 bar_foo_toolset = await foo_toolset.prepare_for_run(bar_context) assert bar_foo_toolset == bar_toolset From c1115aece71a36f9c51f24e7f8a086c75d5d7630 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Sat, 21 Jun 2025 06:16:08 +0000 Subject: [PATCH 55/90] Rename pending tools to deferred tools --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 22 +++++------ pydantic_ai_slim/pydantic_ai/_output.py | 40 ++++++++++---------- pydantic_ai_slim/pydantic_ai/output.py | 2 +- pydantic_ai_slim/pydantic_ai/tools.py | 2 +- tests/test_tools.py | 12 +++--- 5 files changed, 40 insertions(+), 38 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 5d38df0bd..4bc489dfc 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -22,7 +22,7 @@ from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage -from .output import OutputDataT, OutputSpec, PendingToolCalls +from .output import DeferredToolCalls, OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings from .tools import RunContext, ToolDefinition, ToolKind @@ -549,8 +549,8 @@ async def _handle_tool_calls( # noqa: C901 ): yield event - pending_calls: list[_messages.ToolCallPart] = [] - for call in tool_calls_by_kind['pending']: + deferred_calls: list[_messages.ToolCallPart] = [] + for call in tool_calls_by_kind['deferred']: if final_result: parts.append( _messages.ToolReturnPart( @@ -561,19 +561,19 @@ async def _handle_tool_calls( # noqa: C901 ) else: yield _messages.FunctionToolCallEvent(call) - pending_calls.append(call) + deferred_calls.append(call) - if pending_calls: - if not ctx.deps.output_schema.pending_tool_calls: + if deferred_calls: + if not ctx.deps.output_schema.deferred_tool_calls: raise exceptions.UserError( - 'There are pending tool calls but PendingToolCalls is not among output types.' + 'There are pending tool calls but DeferredToolCalls is not among output types.' ) - pending_tool_names = [call.tool_name for call in pending_calls] - pending_tool_defs = { - tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in pending_tool_names + deferred_tool_names = [call.tool_name for call in deferred_calls] + deferred_tool_defs = { + tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in deferred_tool_names } - output_data = cast(NodeRunEndT, PendingToolCalls(pending_calls, pending_tool_defs)) + output_data = cast(NodeRunEndT, DeferredToolCalls(deferred_calls, deferred_tool_defs)) final_result = result.FinalResult(output_data) if final_result: diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 163a14142..22020b145 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -15,12 +15,12 @@ from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, UserError from .output import ( + DeferredToolCalls, ModelStructuredOutput, OutputDataT, OutputMode, OutputSpec, OutputTypeOrFunction, - PendingToolCalls, PromptedStructuredOutput, StructuredOutputMode, TextOutput, @@ -129,7 +129,7 @@ async def validate( @dataclass class BaseOutputSchema(ABC, Generic[OutputDataT]): - pending_tool_calls: bool + deferred_tool_calls: bool @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: @@ -182,8 +182,8 @@ def build( # noqa: C901 """Build an OutputSchema dataclass from an output type.""" raw_outputs = _flatten_output_spec(output_spec) - outputs = [output for output in raw_outputs if output is not PendingToolCalls] - pending_tool_calls = len(outputs) < len(raw_outputs) + outputs = [output for output in raw_outputs if output is not DeferredToolCalls] + deferred_tool_calls = len(outputs) < len(raw_outputs) if output := next((output for output in outputs if isinstance(output, ModelStructuredOutput)), None): if len(outputs) > 1: raise UserError('ModelStructuredOutput cannot be mixed with other output types.') @@ -194,7 +194,7 @@ def build( # noqa: C901 name=output.name, description=output.description, ), - pending_tool_calls=pending_tool_calls, + deferred_tool_calls=deferred_tool_calls, ) elif output := next((output for output in outputs if isinstance(output, PromptedStructuredOutput)), None): if len(outputs) > 1: @@ -207,7 +207,7 @@ def build( # noqa: C901 description=output.description, ), template=output.template, - pending_tool_calls=pending_tool_calls, + deferred_tool_calls=deferred_tool_calls, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] @@ -240,20 +240,20 @@ def build( # noqa: C901 text_output_schema = PlainTextOutputProcessor(text_output.output_function) if len(tools) == 0: - return PlainTextOutputSchema(processor=text_output_schema, pending_tool_calls=pending_tool_calls) + return PlainTextOutputSchema(processor=text_output_schema, deferred_tool_calls=deferred_tool_calls) else: return ToolOrTextOutputSchema( - processor=text_output_schema, tools=tools, pending_tool_calls=pending_tool_calls + processor=text_output_schema, tools=tools, deferred_tool_calls=deferred_tool_calls ) if len(tool_outputs) > 0: - return ToolOutputSchema(tools=tools, pending_tool_calls=pending_tool_calls) + return ToolOutputSchema(tools=tools, deferred_tool_calls=deferred_tool_calls) if len(other_outputs) > 0: schema = OutputSchemaWithoutMode( processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), tools=tools, - pending_tool_calls=pending_tool_calls, + deferred_tool_calls=deferred_tool_calls, ) if default_mode: schema = schema.with_default_mode(default_mode) @@ -343,19 +343,21 @@ def __init__( self, processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], tools: dict[str, OutputTool[OutputDataT]], - pending_tool_calls: bool, + deferred_tool_calls: bool, ): - super().__init__(pending_tool_calls) + super().__init__(deferred_tool_calls) self.processor = processor self._tools = tools def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'model_structured': - return ModelStructuredOutputSchema(processor=self.processor, pending_tool_calls=self.pending_tool_calls) + return ModelStructuredOutputSchema(processor=self.processor, deferred_tool_calls=self.deferred_tool_calls) elif mode == 'prompted_structured': - return PromptedStructuredOutputSchema(processor=self.processor, pending_tool_calls=self.pending_tool_calls) + return PromptedStructuredOutputSchema( + processor=self.processor, deferred_tool_calls=self.deferred_tool_calls + ) elif mode == 'tool': - return ToolOutputSchema(tools=self.tools, pending_tool_calls=self.pending_tool_calls) + return ToolOutputSchema(tools=self.tools, deferred_tool_calls=self.deferred_tool_calls) else: assert_never(mode) @@ -517,8 +519,8 @@ async def process( class ToolOutputSchema(OutputSchema[OutputDataT]): _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) - def __init__(self, tools: dict[str, OutputTool[OutputDataT]], pending_tool_calls: bool): - super().__init__(pending_tool_calls) + def __init__(self, tools: dict[str, OutputTool[OutputDataT]], deferred_tool_calls: bool): + super().__init__(deferred_tool_calls) self._tools = tools @property @@ -561,9 +563,9 @@ def __init__( self, processor: PlainTextOutputProcessor[OutputDataT] | None, tools: dict[str, OutputTool[OutputDataT]], - pending_tool_calls: bool, + deferred_tool_calls: bool, ): - super().__init__(tools=tools, pending_tool_calls=pending_tool_calls) + super().__init__(tools=tools, deferred_tool_calls=deferred_tool_calls) self.processor = processor @property diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 9a84f2f9b..7e921c653 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -233,7 +233,7 @@ def split_into_words(text: str) -> list[str]: @dataclass -class PendingToolCalls: +class DeferredToolCalls: """Output type for calls to tools defined as pending.""" tool_calls: list[ToolCallPart] diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 2575b025f..201081c5f 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -324,7 +324,7 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_parts=Any` """ -ToolKind: TypeAlias = Literal['function', 'output', 'pending'] +ToolKind: TypeAlias = Literal['function', 'output', 'deferred'] """Kind of tool.""" diff --git a/tests/test_tools.py b/tests/test_tools.py index 6bc0eecfc..a31369db4 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -16,7 +16,7 @@ from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import PendingToolCalls, ToolOutput +from pydantic_ai.output import DeferredToolCalls, ToolOutput from pydantic_ai.tools import ToolDefinition from .conftest import IsStr @@ -1180,11 +1180,11 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int: assert call_retries == [0, 1, 2, 3, 4, 5] -def test_pending_tool(): - agent = Agent(TestModel(), output_type=[str, PendingToolCalls]) +def test_deferred_tool(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: - return replace(tool_def, kind='pending') + return replace(tool_def, kind='deferred') @agent.tool_plain(prepare=prepare_tool) def my_tool(x: int) -> int: @@ -1192,7 +1192,7 @@ def my_tool(x: int) -> int: result = agent.run_sync('Hello') assert result.output == snapshot( - PendingToolCalls( + DeferredToolCalls( tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], tool_defs={ 'my_tool': ToolDefinition( @@ -1204,7 +1204,7 @@ def my_tool(x: int) -> int: 'required': ['x'], 'type': 'object', }, - kind='pending', + kind='deferred', ) }, ) From a2f69dff18b8616724f6a18616b5e238906845c4 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 24 Jun 2025 18:25:23 +0000 Subject: [PATCH 56/90] Fix retries --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 5bf7692b0..eb3d82c5c 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -486,7 +486,7 @@ async def _handle_tool_calls( # noqa: C901 final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] - toolset = await CombinedToolset([ctx.deps.toolset, ctx.deps.output_toolset]).prepare_for_run(run_context) + toolset = CombinedToolset([ctx.deps.toolset, ctx.deps.output_toolset]) unknown_calls: list[_messages.ToolCallPart] = [] tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list) From 0e0bf35e533e4f973495d4372a2aa31010e0b355 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 24 Jun 2025 18:35:35 +0000 Subject: [PATCH 57/90] Remove duplicate cassettes --- ..._anthropic_prompted_structured_output.yaml | 161 ------- ...c_prompted_structured_output_multiple.yaml | 66 --- .../test_gemini_model_structured_output.yaml | 79 ---- ...mini_model_structured_output_multiple.yaml | 120 ----- ...est_gemini_prompted_structured_output.yaml | 74 --- ...i_prompted_structured_output_multiple.yaml | 73 --- ...prompted_structured_output_with_tools.yaml | 157 ------- ...ogle_model_structured_output_multiple.yaml | 138 ------ ...est_google_prompted_structured_output.yaml | 78 --- ...e_prompted_structured_output_multiple.yaml | 77 --- ...prompted_structured_output_with_tools.yaml | 164 ------- .../test_openai_model_structured_output.yaml | 223 --------- ...enai_model_structured_output_multiple.yaml | 293 ------------ ...est_openai_prompted_structured_output.yaml | 209 --------- ...i_prompted_structured_output_multiple.yaml | 209 --------- .../test_model_structured_output.yaml | 288 ------------ ...test_model_structured_output_multiple.yaml | 444 ------------------ .../test_prompted_structured_output.yaml | 248 ---------- ...t_prompted_structured_output_multiple.yaml | 248 ---------- 19 files changed, 3349 deletions(-) delete mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output.yaml delete mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_model_structured_output.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_model_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_with_tools.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_model_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_prompted_structured_output.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_prompted_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_google/test_google_prompted_structured_output_with_tools.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_model_structured_output.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_model_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_prompted_structured_output.yaml delete mode 100644 tests/models/cassettes/test_openai/test_openai_prompted_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_model_structured_output.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_model_structured_output_multiple.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_prompted_structured_output.yaml delete mode 100644 tests/models/cassettes/test_openai_responses/test_prompted_structured_output_multiple.yaml diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output.yaml deleted file mode 100644 index e88afebdf..000000000 --- a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output.yaml +++ /dev/null @@ -1,161 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '740' - content-type: - - application/json - host: - - api.anthropic.com - method: POST - parsed_body: - max_tokens: 1024 - messages: - - content: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - type: text - role: user - model: claude-3-5-sonnet-latest - stream: false - system: |+ - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - - tool_choice: - type: auto - tools: - - description: '' - input_schema: - additionalProperties: false - properties: {} - type: object - name: get_user_country - uri: https://api.anthropic.com/v1/messages?beta=true - response: - headers: - connection: - - keep-alive - content-length: - - '397' - content-type: - - application/json - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - content: - - id: toolu_017UryVwtsKsjonhFV3cgV3X - input: {} - name: get_user_country - type: tool_use - id: msg_014CpBKzioMqUyLWrMihpvsz - model: claude-3-5-sonnet-20241022 - role: assistant - stop_reason: tool_use - stop_sequence: null - type: message - usage: - cache_creation_input_tokens: 0 - cache_read_input_tokens: 0 - input_tokens: 459 - output_tokens: 38 - service_tier: standard - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1002' - content-type: - - application/json - host: - - api.anthropic.com - method: POST - parsed_body: - max_tokens: 1024 - messages: - - content: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - type: text - role: user - - content: - - id: toolu_017UryVwtsKsjonhFV3cgV3X - input: {} - name: get_user_country - type: tool_use - role: assistant - - content: - - content: Mexico - is_error: false - tool_use_id: toolu_017UryVwtsKsjonhFV3cgV3X - type: tool_result - role: user - model: claude-3-5-sonnet-latest - stream: false - system: |+ - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - - tool_choice: - type: auto - tools: - - description: '' - input_schema: - additionalProperties: false - properties: {} - type: object - name: get_user_country - uri: https://api.anthropic.com/v1/messages?beta=true - response: - headers: - connection: - - keep-alive - content-length: - - '380' - content-type: - - application/json - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - content: - - text: '{"city": "Mexico City", "country": "Mexico"}' - type: text - id: msg_014JeWCouH6DpdqzMTaBdkpJ - model: claude-3-5-sonnet-20241022 - role: assistant - stop_reason: end_turn - stop_sequence: null - type: message - usage: - cache_creation_input_tokens: 0 - cache_read_input_tokens: 0 - input_tokens: 510 - output_tokens: 17 - service_tier: standard - status: - code: 200 - message: OK -version: 1 -... diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output_multiple.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output_multiple.yaml deleted file mode 100644 index 183daa406..000000000 --- a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_structured_output_multiple.yaml +++ /dev/null @@ -1,66 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1268' - content-type: - - application/json - host: - - api.anthropic.com - method: POST - parsed_body: - max_tokens: 1024 - messages: - - content: - - text: What is the largest city in Mexico? - type: text - role: user - model: claude-3-5-sonnet-latest - stream: false - system: |+ - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - - uri: https://api.anthropic.com/v1/messages?beta=true - response: - headers: - connection: - - keep-alive - content-length: - - '434' - content-type: - - application/json - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - content: - - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' - type: text - id: msg_013ttUi3HCcKt7PkJpoWs5FT - model: claude-3-5-sonnet-20241022 - role: assistant - stop_reason: end_turn - stop_sequence: null - type: message - usage: - cache_creation_input_tokens: 0 - cache_read_input_tokens: 0 - input_tokens: 281 - output_tokens: 31 - service_tier: standard - status: - code: 200 - message: OK -version: 1 -... diff --git a/tests/models/cassettes/test_gemini/test_gemini_model_structured_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_model_structured_output.yaml deleted file mode 100644 index d7f14c9ca..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_model_structured_output.yaml +++ /dev/null @@ -1,79 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '305' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - response_mime_type: application/json - response_schema: - properties: - city: - type: string - country: - type: string - required: - - city - - country - title: CityLocation - type: object - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '710' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=819 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.00018302639946341515 - content: - parts: - - text: |- - { - "city": "Mexico City", - "country": "Mexico" - } - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: SEVIaJvJHICK7dcP3OzRiQQ - usageMetadata: - candidatesTokenCount: 20 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 20 - promptTokenCount: 17 - promptTokensDetails: - - modality: TEXT - tokenCount: 17 - totalTokenCount: 37 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_model_structured_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_model_structured_output_multiple.yaml deleted file mode 100644 index 3b306d133..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_model_structured_output_multiple.yaml +++ /dev/null @@ -1,120 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '791' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the primarily language spoken in Mexico? - role: user - generationConfig: - response_mime_type: application/json - response_schema: - properties: - result: - anyOf: - - description: CityLocation - properties: - data: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - enum: - - CityLocation - type: string - required: - - kind - - data - type: object - - description: CountryLanguage - properties: - data: - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - enum: - - CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '800' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=963 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -3.3667640072172103e-06 - content: - parts: - - text: |- - { - "result": { - "data": { - "country": "Mexico", - "language": "Spanish" - }, - "kind": "CountryLanguage" - } - } - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 2jxIaPucEYCK7dcP3OzRiQQ - usageMetadata: - candidatesTokenCount: 46 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 46 - promptTokenCount: 46 - promptTokensDetails: - - modality: TEXT - tokenCount: 46 - totalTokenCount: 92 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output.yaml deleted file mode 100644 index 2268e7f84..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output.yaml +++ /dev/null @@ -1,74 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '521' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - response_mime_type: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '880' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=841 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.007913463882037572 - content: - parts: - - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], - "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 2zxIaIiLE4CK7dcP3OzRiQQ - usageMetadata: - candidatesTokenCount: 56 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 56 - promptTokenCount: 80 - promptTokensDetails: - - modality: TEXT - tokenCount: 80 - totalTokenCount: 136 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_multiple.yaml deleted file mode 100644 index e96fc20d7..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_multiple.yaml +++ /dev/null @@ -1,73 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1287' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - response_mime_type: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: user - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '757' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=823 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0030997690779191477 - content: - parts: - - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: Wz1IaOH5OdGU7dcPjoS34QI - usageMetadata: - candidatesTokenCount: 27 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 27 - promptTokenCount: 253 - promptTokensDetails: - - modality: TEXT - tokenCount: 253 - totalTokenCount: 280 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_with_tools.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_with_tools.yaml deleted file mode 100644 index f10da3ad7..000000000 --- a/tests/models/cassettes/test_gemini/test_gemini_prompted_structured_output_with_tools.yaml +++ /dev/null @@ -1,157 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '615' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '653' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=4501 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - functionCall: - args: {} - name: get_user_country - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: rj9IaPTzNdCBqtsPg-GD6QU - usageMetadata: - candidatesTokenCount: 12 - promptTokenCount: 123 - promptTokensDetails: - - modality: TEXT - tokenCount: 123 - thoughtsTokenCount: 318 - totalTokenCount: 453 - status: - code: 200 - message: OK -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '809' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - - parts: - - functionCall: - args: {} - name: get_user_country - role: model - - parts: - - functionResponse: - name: get_user_country - response: - return_value: Mexico - role: user - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '616' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=1823 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - text: '{"city": "Mexico City", "country": "Mexico"}' - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: sD9IaOCyLPqumtkP6p_T0AE - usageMetadata: - candidatesTokenCount: 13 - promptTokenCount: 154 - promptTokensDetails: - - modality: TEXT - tokenCount: 154 - thoughtsTokenCount: 94 - totalTokenCount: 261 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_model_structured_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_model_structured_output_multiple.yaml deleted file mode 100644 index 74dd03c89..000000000 --- a/tests/models/cassettes/test_google/test_google_model_structured_output_multiple.yaml +++ /dev/null @@ -1,138 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1200' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the primarily language spoken in Mexico? - role: user - generationConfig: - responseMimeType: application/json - responseSchema: - description: The final response which ends this conversation - properties: - result: - any_of: - - description: CityLocation - properties: - data: - properties: - city: - type: STRING - country: - type: STRING - property_ordering: - - city - - country - required: - - city - - country - type: OBJECT - kind: - enum: - - CityLocation - type: STRING - property_ordering: - - kind - - data - required: - - kind - - data - type: OBJECT - - description: CountryLanguage - properties: - data: - properties: - country: - type: STRING - language: - type: STRING - property_ordering: - - country - - language - required: - - country - - language - type: OBJECT - kind: - enum: - - CountryLanguage - type: STRING - property_ordering: - - kind - - data - required: - - kind - - data - type: OBJECT - required: - - result - title: final_result - type: OBJECT - toolConfig: - functionCallingConfig: - allowedFunctionNames: [] - mode: ANY - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '800' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=884 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0002536005138055138 - content: - parts: - - text: |- - { - "result": { - "kind": "CountryLanguage", - "data": { - "country": "Mexico", - "language": "Spanish" - } - } - } - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: W29HaJzGMNGU7dcPjoS34QI - usageMetadata: - candidatesTokenCount: 46 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 46 - promptTokenCount: 64 - promptTokensDetails: - - modality: TEXT - tokenCount: 64 - totalTokenCount: 110 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_structured_output.yaml b/tests/models/cassettes/test_google/test_google_prompted_structured_output.yaml deleted file mode 100644 index 3b241acae..000000000 --- a/tests/models/cassettes/test_google/test_google_prompted_structured_output.yaml +++ /dev/null @@ -1,78 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '619' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - responseMimeType: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - toolConfig: - functionCallingConfig: - allowedFunctionNames: [] - mode: ANY - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '879' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=829 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.010130892906870161 - content: - parts: - - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], - "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 4HlHaK75MdGU7dcPjoS34QI - usageMetadata: - candidatesTokenCount: 56 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 56 - promptTokenCount: 80 - promptTokensDetails: - - modality: TEXT - tokenCount: 80 - totalTokenCount: 136 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_structured_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_prompted_structured_output_multiple.yaml deleted file mode 100644 index 33383473f..000000000 --- a/tests/models/cassettes/test_google/test_google_prompted_structured_output_multiple.yaml +++ /dev/null @@ -1,77 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1341' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in Mexico? - role: user - generationConfig: - responseMimeType: application/json - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: user - toolConfig: - functionCallingConfig: - allowedFunctionNames: [] - mode: ANY - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '758' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=734 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - avgLogprobs: -0.0008548707873732956 - content: - parts: - - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' - role: model - finishReason: STOP - modelVersion: gemini-2.0-flash - responseId: 6nlHaO_5GdeI_NUPmYvnoA8 - usageMetadata: - candidatesTokenCount: 27 - candidatesTokensDetails: - - modality: TEXT - tokenCount: 27 - promptTokenCount: 241 - promptTokensDetails: - - modality: TEXT - tokenCount: 241 - totalTokenCount: 268 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_structured_output_with_tools.yaml b/tests/models/cassettes/test_google/test_google_prompted_structured_output_with_tools.yaml deleted file mode 100644 index 976533c66..000000000 --- a/tests/models/cassettes/test_google/test_google_prompted_structured_output_with_tools.yaml +++ /dev/null @@ -1,164 +0,0 @@ -interactions: -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '658' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - generationConfig: {} - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '653' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=3776 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - functionCall: - args: {} - name: get_user_country - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: FnpHaOqcKrzQz7IPkuLo8QE - usageMetadata: - candidatesTokenCount: 12 - promptTokenCount: 123 - promptTokensDetails: - - modality: TEXT - tokenCount: 123 - thoughtsTokenCount: 266 - totalTokenCount: 401 - status: - code: 200 - message: OK -- request: - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '967' - content-type: - - application/json - host: - - generativelanguage.googleapis.com - method: POST - parsed_body: - contents: - - parts: - - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. - role: user - - parts: - - functionCall: - args: {} - id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 - name: get_user_country - role: model - - parts: - - functionResponse: - id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 - name: get_user_country - response: - return_value: Mexico - role: user - generationConfig: {} - systemInstruction: - parts: - - text: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: user - tools: - - functionDeclarations: - - description: '' - name: get_user_country - uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent - response: - headers: - alt-svc: - - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 - content-length: - - '630' - content-type: - - application/json; charset=UTF-8 - server-timing: - - gfet4t7; dur=1888 - transfer-encoding: - - chunked - vary: - - Origin - - X-Origin - - Referer - parsed_body: - candidates: - - content: - parts: - - text: |- - ```json - {"city": "Mexico City", "country": "Mexico"} - ``` - role: model - finishReason: STOP - index: 0 - modelVersion: models/gemini-2.5-pro-preview-05-06 - responseId: GHpHaOPkI43Qz7IPxt6T2Ac - usageMetadata: - candidatesTokenCount: 18 - promptTokenCount: 154 - promptTokensDetails: - - modality: TEXT - tokenCount: 154 - thoughtsTokenCount: 94 - totalTokenCount: 266 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_model_structured_output.yaml b/tests/models/cassettes/test_openai/test_openai_model_structured_output.yaml deleted file mode 100644 index ff4477f3d..000000000 --- a/tests/models/cassettes/test_openai/test_openai_model_structured_output.yaml +++ /dev/null @@ -1,223 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '522' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - n: 1 - response_format: - json_schema: - name: result - schema: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: false - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1066' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '341' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_PkRGedQNRFUzJp2R7dO7avWR - type: function - created: 1746142582 - id: chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3 - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_f5bdcc3276 - usage: - completion_tokens: 12 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 71 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 83 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '753' - content-type: - - application/json - cookie: - - __cf_bm=dOa3_E1SoWV.vbgZ8L7tx8o9S.XyNTE.YS0K0I3JHq4-1746142583-1.0.1.1-0TuvhdYsoD.J1522DBXH0yrAP_M9MlzvlcpyfwQQNZy.KO5gri6ejQ.gFuwLV5hGhuY0W2uI1dN7ZF1lirVHKeEnEz5s_89aJjrMWjyBd8M; - _cfuvid=xQIJVHkOP28w5fPnAvDHPiCRlU7kkNj6iFV87W4u8Ds-1746142583128-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_PkRGedQNRFUzJp2R7dO7avWR - type: function - - content: Mexico - role: tool - tool_call_id: call_PkRGedQNRFUzJp2R7dO7avWR - model: gpt-4o - n: 1 - response_format: - json_schema: - name: result - schema: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: false - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '852' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '553' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"city":"Mexico City","country":"Mexico"}' - refusal: null - role: assistant - created: 1746142583 - id: chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_f5bdcc3276 - usage: - completion_tokens: 15 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 92 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 107 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_model_structured_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_model_structured_output_multiple.yaml deleted file mode 100644 index d01e28ab0..000000000 --- a/tests/models/cassettes/test_openai/test_openai_model_structured_output_multiple.yaml +++ /dev/null @@ -1,293 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1120' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - response_format: - json_schema: - description: The final response which ends this conversation - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - required: - - kind - - data - type: object - required: - - result - type: object - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1068' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '868' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_SIttSeiOistt33Htj4oiHOOX - type: function - created: 1749511286 - id: chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 160 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 171 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1351' - content-type: - - application/json - cookie: - - __cf_bm=OFzdr.HrmtC0DNdnfrTQYsK8_PwAVR9GUqjYSCgwtVM-1749511286-1.0.1.1-9_dbth7ET4rzl01UDRTw3fY1nJ20FnMCC0BBmd57gzKF8n5DnNQaI4K1mT.23nn9IUsMyHAZUNn6t1EML3d7GfGJyiLZOxrTWaqacALgzlM; - _cfuvid=f32dQYPsRd6Jc7kg.3hHa1QYAyG8f_aMMXUF.bC6gmY-1749511286914-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_SIttSeiOistt33Htj4oiHOOX - type: function - - content: Mexico - role: tool - tool_call_id: call_SIttSeiOistt33Htj4oiHOOX - model: gpt-4o - response_format: - json_schema: - description: The final response which ends this conversation - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - required: - - kind - - data - type: object - required: - - result - type: object - type: json_schema - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '903' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '920' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - refusal: null - role: assistant - created: 1749511287 - id: chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 25 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 181 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 206 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_prompted_structured_output.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_structured_output.yaml deleted file mode 100644 index 4eed79085..000000000 --- a/tests/models/cassettes/test_openai/test_openai_prompted_structured_output.yaml +++ /dev/null @@ -1,209 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '690' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1068' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '569' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_s7oT9jaLAsEqTgvxZTmFh0wB - type: function - created: 1749514895 - id: chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_07871e2ad8 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 109 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 120 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '921' - content-type: - - application/json - cookie: - - __cf_bm=jcec.FXQ2vs1UTNFhcDbuMrvzdFu7d7L1To24_vRFiQ-1749514896-1.0.1.1-PEeul2ZYkvLFmEXXk4Xlgvun2HcuGEJ0UUliLVWKx17kMCjZ8WiZbB2Yavq3RRGlxsJZsAWIVMQQ10Vb_2aqGVtQ2aiYTlnDMX3Ktkuciyk; - _cfuvid=zanrNpp5OAiS0wLKfkW9LCs3qTO2FvIaiBZptR_D2P0-1749514896187-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_s7oT9jaLAsEqTgvxZTmFh0wB - type: function - - content: Mexico - role: tool - tool_call_id: call_s7oT9jaLAsEqTgvxZTmFh0wB - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '853' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '718' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"city":"Mexico City","country":"Mexico"}' - refusal: null - role: assistant - created: 1749514896 - id: chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0 - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_07871e2ad8 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 130 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 141 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_prompted_structured_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_structured_output_multiple.yaml deleted file mode 100644 index 3d3ba886a..000000000 --- a/tests/models/cassettes/test_openai/test_openai_prompted_structured_output_multiple.yaml +++ /dev/null @@ -1,209 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1412' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1068' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '428' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_wJD14IyJ4KKVtjCrGyNCHO09 - type: function - created: 1749514898 - id: chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 11 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 273 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 284 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1643' - content-type: - - application/json - cookie: - - __cf_bm=gqjIEMZSez95CPkkPVuU_AoDutHrobFMbFPjq43G66M-1749514899-1.0.1.1-5TGB9WajW5pzCRtVtWeQfiwyQUZs1JwWy9qC8VGlgq7s5pQWKerukQtYB7GqNDrdb.1pbtFyt2HZ9xV3YiSbK4H1bZS_hS1CCeoup_3IQW0; - _cfuvid=ZN6eoNau4b.bJ8kvRn2z9R0HgTUd9nOsupKUtLXQowU-1749514899280-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_user_country - id: call_wJD14IyJ4KKVtjCrGyNCHO09 - type: function - - content: Mexico - role: tool - tool_call_id: call_wJD14IyJ4KKVtjCrGyNCHO09 - model: gpt-4o - response_format: - type: json_object - stream: false - tool_choice: auto - tools: - - function: - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '903' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '763' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - refusal: null - role: assistant - created: 1749514899 - id: chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_9bddfca6e2 - usage: - completion_tokens: 21 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 294 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 315 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_model_structured_output.yaml b/tests/models/cassettes/test_openai_responses/test_model_structured_output.yaml deleted file mode 100644 index 9fd1b6989..000000000 --- a/tests/models/cassettes/test_openai_responses/test_model_structured_output.yaml +++ /dev/null @@ -1,288 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '533' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1808' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '636' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516047 - error: null - id: resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_tTAThu8l2S9hNky2krdwijGP - id: fc_68477f0fa7c081a19a525f7c6f180f310b8591d9001d2329 - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 66 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 78 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '769' - content-type: - - application/json - cookie: - - __cf_bm=My3TWVEPFsaYcjJ.iWxTB6P67jFSuxSF.n13qHpH9BA-1749516047-1.0.1.1-2bg2ltV1yu2uhfqewI9eEG1ulzfU_gq8pLx9YwHte33BTk2PgxBwaRdyegdEs_dVkAbaCoAPsQRIQmW21QPf_U2Fd1vdibnoExA_.rvTYv8; - _cfuvid=_7XoQBGwU.UsQgiPHVWMTXLLbADtbSwhrO9PY7I_3Dw-1749516047790-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_tTAThu8l2S9hNky2krdwijGP - name: get_user_country - type: function_call - - call_id: call_tTAThu8l2S9hNky2krdwijGP - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1902' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '883' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516047 - error: null - id: resp_68477f0fde708192989000a62809c6e5020197534e39cc1f - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"city":"Mexico City","country":"Mexico"}' - type: output_text - id: msg_68477f10846c81929f1e833b0785e6f3020197534e39cc1f - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: CityLocation - schema: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 89 - input_tokens_details: - cached_tokens: 0 - output_tokens: 16 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 105 - user: null - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_model_structured_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_model_structured_output_multiple.yaml deleted file mode 100644 index 9c411f3c7..000000000 --- a/tests/models/cassettes/test_openai_responses/test_model_structured_output_multiple.yaml +++ /dev/null @@ -1,444 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1143' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '3657' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '562' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516048 - error: null - id: resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP - id: fc_68477f1168a081a3981e847cd94275080dd57d732903c563 - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 153 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 165 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1379' - content-type: - - application/json - cookie: - - __cf_bm=3Nl1ERbtfVAI.dGjzCYYN1u71YD5eEoLU0iCrvPPPL0-1749516049-1.0.1.1-LnI7tJwKr.C_wA15Shsl8pcGd32zrRqqv_9u4S84nXtNCopx1iBIKYDsyMg3u1Z3lJ_1Cd1YVM8uKAMjiKmgoqS8GFQ3Z_vV_Mahvqbi4KA; - _cfuvid=oc_k9l86fnMo2ml.0aop6a3eVJEvjxB0lnxWK0_kJq8-1749516049524-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP - name: get_user_country - type: function_call - - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '3800' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '1042' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749516049 - error: null - id: resp_68477f119830819da162aa6e10552035061ad97e2eef7871 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - type: output_text - id: msg_68477f1235b8819d898adc64709c7ebf061ad97e2eef7871 - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - description: null - name: final_result - schema: - additionalProperties: false - properties: - result: - anyOf: - - additionalProperties: false - description: CityLocation - properties: - data: - additionalProperties: false - properties: - city: - type: string - country: - type: string - required: - - city - - country - type: object - kind: - const: CityLocation - type: string - required: - - kind - - data - type: object - - additionalProperties: false - description: CountryLanguage - properties: - data: - additionalProperties: false - properties: - country: - type: string - language: - type: string - required: - - country - - language - type: object - kind: - const: CountryLanguage - type: string - required: - - kind - - data - type: object - required: - - result - type: object - strict: true - type: json_schema - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 176 - input_tokens_details: - cached_tokens: 0 - output_tokens: 26 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 202 - user: null - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_structured_output.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_structured_output.yaml deleted file mode 100644 index 35783c516..000000000 --- a/tests/models/cassettes/test_openai_responses/test_prompted_structured_output.yaml +++ /dev/null @@ -1,248 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '689' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1408' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '8314' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561106 - error: null - id: resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom - id: fc_68482f1b0ff081a1b37b9170ee740d1e02f8ef7f2fb42b50 - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 107 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 119 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '925' - content-type: - - application/json - cookie: - - __cf_bm=8a8rNQQYozQt3YjcA61k6KGe.AlrMMrtcIvKv.D1s1E-1749561115-1.0.1.1-OFcqg8xD2_HdbeO74bU2.mLTqDuiK.ploHeu3_ITPvDlGwrVkwk8erMkHagxk4UDxACCCAygnUs1HL.F4AGjQCaZm1m2eYiMVbLqp0iQh7g; - _cfuvid=wKTRRc2dbdYNYnYwA2vRxVjUvqqkQovvKDwULW0Xwns-1749561115173-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom - name: get_user_country - type: function_call - - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1501' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '1098' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561115 - error: null - id: resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"city":"Mexico City","country":"Mexico"}' - type: output_text - id: msg_68482f1c159081918a2405f458009a6a044fdb7d019d4115 - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 130 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 142 - user: null - status: - code: 200 - message: OK -version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_structured_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_structured_output_multiple.yaml deleted file mode 100644 index 1a3b4dc00..000000000 --- a/tests/models/cassettes/test_openai_responses/test_prompted_structured_output_multiple.yaml +++ /dev/null @@ -1,248 +0,0 @@ -interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1455' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1408' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '11445' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561117 - error: null - id: resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - arguments: '{}' - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI - id: fc_68482f2889d481a199caa61de7ccb62c08e79646fe74d5ee - name: get_user_country - status: completed - type: function_call - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 283 - input_tokens_details: - cached_tokens: 0 - output_tokens: 12 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 295 - user: null - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1691' - content-type: - - application/json - cookie: - - __cf_bm=l95LdgPzGHw0UAhBwse9ADphgmMDWrhYqgiO4gdmSy4-1749561128-1.0.1.1-9zPIs3d5_ipszLpQ7yBaCZEStp8qoRIGFshR93V6n7Z_7AznH0MfuczwuoiaW8e6cEVeVHLhskjXScolO9gP5TmpsaFo37GRuHsHZTRgEeI; - _cfuvid=5L5qtbtbFCFzMmoVufSY.ksn06ay8AFs.UXFEv07pkY-1749561128680-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - input: - - content: |- - Always respond with a JSON object that's compatible with this schema: - - {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} - - Don't include any text or Markdown fencing before or after. - role: system - - content: What is the largest city in the user country? - role: user - - content: '' - role: assistant - - arguments: '{}' - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI - name: get_user_country - type: function_call - - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI - output: Mexico - type: function_call_output - model: gpt-4o - stream: false - text: - format: - type: json_object - tool_choice: auto - tools: - - description: '' - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - uri: https://api.openai.com/v1/responses - response: - headers: - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1551' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '2545' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - background: false - created_at: 1749561128 - error: null - id: resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024 - incomplete_details: null - instructions: null - max_output_tokens: null - metadata: {} - model: gpt-4o-2024-08-06 - object: response - output: - - content: - - annotations: [] - text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' - type: output_text - id: msg_68482f296bfc81a18665547d4008ab2c06b4ab2d00d03024 - role: assistant - status: completed - type: message - parallel_tool_calls: true - previous_response_id: null - reasoning: - effort: null - summary: null - service_tier: default - status: completed - store: true - temperature: 1.0 - text: - format: - type: json_object - tool_choice: auto - tools: - - description: null - name: get_user_country - parameters: - additionalProperties: false - properties: {} - type: object - strict: false - type: function - top_p: 1.0 - truncation: disabled - usage: - input_tokens: 306 - input_tokens_details: - cached_tokens: 0 - output_tokens: 22 - output_tokens_details: - reasoning_tokens: 0 - total_tokens: 328 - user: null - status: - code: 200 - message: OK -version: 1 From 8745a7ac506ec0372ef2adbc8ac3ea59d9302c2d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 26 Jun 2025 00:38:31 +0000 Subject: [PATCH 58/90] Pass just one toolset into the run --- docs/agents.md | 2 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 23 +++++++++++--------- pydantic_ai_slim/pydantic_ai/agent.py | 8 ++----- pydantic_ai_slim/pydantic_ai/toolset.py | 9 ++++++-- tests/models/test_model_test.py | 4 ++-- tests/test_agent.py | 4 ++-- tests/test_tools.py | 2 +- 7 files changed, 28 insertions(+), 24 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 3fe60f7eb..732039307 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -808,7 +808,7 @@ with capture_run_messages() as messages: # (2)! result = agent.run_sync('Please get me the volume of a box with size 6.') except UnexpectedModelBehavior as e: print('An error occurred:', e) - #> An error occurred: Tool exceeded max retries count of 1 + #> An error occurred: Tool 'calc_volume' exceeded max retries count of 1 print('cause:', repr(e.__cause__)) #> cause: ModelRetry('Please try again.') print('messages:', messages) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 3e3eb691b..4168be86f 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -17,7 +17,7 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._utils import is_async_callable, run_in_executor -from pydantic_ai.toolset import AbstractToolset, CombinedToolset, RunToolset +from pydantic_ai.toolset import AbstractToolset, RunToolset from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT @@ -107,7 +107,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]] output_schema: _output.OutputSchema[OutputDataT] - output_toolset: RunToolset[DepsT] output_validators: list[_output.OutputValidator[DepsT, OutputDataT]] history_processors: Sequence[HistoryProcessor[DepsT]] @@ -249,11 +248,7 @@ async def _prepare_request_parameters( ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" run_context = build_run_context(ctx) - ctx.deps.toolset = toolset = await ctx.deps.toolset.prepare_for_run(run_context) - ctx.deps.output_toolset = output_toolset = await ctx.deps.output_toolset.prepare_for_run(run_context) - - # This will raise errors for any name conflicts - CombinedToolset[DepsT]([output_toolset, toolset]) + ctx.deps.toolset = await ctx.deps.toolset.prepare_for_run(run_context) output_schema = ctx.deps.output_schema output_object = None @@ -263,10 +258,18 @@ async def _prepare_request_parameters( # ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema allow_text_output = isinstance(output_schema, _output.TextOutputSchema) + function_tools: list[ToolDefinition] = [] + output_tools: list[ToolDefinition] = [] + for tool_def in ctx.deps.toolset.tool_defs: + if tool_def.kind == 'output': + output_tools.append(tool_def) + else: + function_tools.append(tool_def) + return models.ModelRequestParameters( - function_tools=toolset.tool_defs, + function_tools=function_tools, output_mode=output_schema.mode, - output_tools=output_toolset.tool_defs, + output_tools=output_tools, output_object=output_object, allow_text_output=allow_text_output, ) @@ -487,7 +490,7 @@ async def _handle_tool_calls( # noqa: C901 final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] - toolset = CombinedToolset([ctx.deps.toolset, ctx.deps.output_toolset]) + toolset = ctx.deps.toolset unknown_calls: list[_messages.ToolCallPart] = [] tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index f67320fc3..2c5d37ca5 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -686,11 +686,8 @@ async def main(): run_step=state.run_step, ) - run_toolset = await self._toolset.prepare_for_run(run_context) - run_output_toolset = await output_toolset.prepare_for_run(run_context) - - # This will raise errors for any name conflicts - CombinedToolset([run_output_toolset, run_toolset]) + toolset = CombinedToolset([output_toolset, self._toolset]) + run_toolset = await toolset.prepare_for_run(run_context) model_settings = merge_model_settings(self.model_settings, model_settings) usage_limits = usage_limits or _usage.UsageLimits() @@ -738,7 +735,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: end_strategy=self.end_strategy, output_schema=output_schema, output_validators=output_validators, - output_toolset=run_output_toolset, history_processors=self.history_processors, toolset=run_toolset, tracer=tracer, diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 01a9400c3..7fa274fec 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -368,6 +368,11 @@ def tool_names(self) -> list[str]: def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: return self._toolset_for_tool_name(name).get_tool_args_validator(ctx, name) + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False + ) -> dict[str, Any]: + return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial) + def max_retries_for_tool(self, name: str) -> int: return self._toolset_for_tool_name(name).max_retries_for_tool(name) @@ -678,9 +683,9 @@ def _on_error(self, name: str, e: Exception) -> Never: max_retries = self.max_retries_for_tool(name) current_retry = self._retries.get(name, 0) if current_retry == max_retries: - raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {max_retries}') from e + raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e else: - self._retries[name] = current_retry + 1 + self._retries[name] = current_retry + 1 # TODO: Reset on successful call! raise e def _validate_tool_name(self, name: str) -> None: diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 9fd95feaf..9d022641f 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -157,7 +157,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel: call_count += 1 raise ModelRetry('Fail') - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"): agent.run_sync('Hello', model=TestModel()) assert call_count == 3 @@ -200,7 +200,7 @@ class ResultModel(BaseModel): agent = Agent('test', output_type=ResultModel, retries=2) - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"): agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1})) diff --git a/tests/test_agent.py b/tests/test_agent.py index b291c0ac8..4bbd649a4 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2293,7 +2293,7 @@ def another_tool(y: int) -> int: tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ), RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", + content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsDatetime(), @@ -2380,7 +2380,7 @@ def another_tool(y: int) -> int: # pragma: no cover ), RetryPromptPart( tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", + content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), diff --git a/tests/test_tools.py b/tests/test_tools.py index a31369db4..c0e9d57e3 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1171,7 +1171,7 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int: call_retries.append(ctx.retry) raise ModelRetry('Please try again.') - with pytest.raises(UnexpectedModelBehavior, match='Tool exceeded max retries count of 5'): + with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"): agent.run_sync('Begin infinite retry loop!') # There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in. From 05aa97250f09275d895fe29d3106261c2c494286 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 26 Jun 2025 20:00:20 +0000 Subject: [PATCH 59/90] WIP --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 233 ++++++++++--------- pydantic_ai_slim/pydantic_ai/agent.py | 15 +- pydantic_ai_slim/pydantic_ai/result.py | 3 + 3 files changed, 127 insertions(+), 124 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 4168be86f..84fac934e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -480,109 +480,21 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: async for event in self._events_iterator: yield event - async def _handle_tool_calls( # noqa: C901 + async def _handle_tool_calls( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: run_context = build_run_context(ctx) - final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] + final_result_holder: list[result.FinalResult[NodeRunEndT]] = [] - toolset = ctx.deps.toolset - - unknown_calls: list[_messages.ToolCallPart] = [] - tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list) - # TODO: Make Toolset.tool_defs a dict - tool_defs_by_name: dict[str, ToolDefinition] = {tool_def.name: tool_def for tool_def in toolset.tool_defs} - for call in tool_calls: - try: - tool_def = tool_defs_by_name[call.tool_name] - tool_calls_by_kind[tool_def.kind].append(call) - except KeyError: - unknown_calls.append(call) - - # first, look for the output tool call - for call in tool_calls_by_kind['output']: - if final_result: - part = _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Output tool not used - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - parts.append(part) - else: - try: - result_data = await _call_tool(toolset, call, run_context) - except _output.ToolRetryError as e: - parts.append(e.tool_retry) - else: - part = _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Final result processed.', - tool_call_id=call.tool_call_id, - ) - parts.append(part) - final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - - # Then build the other request parts based on end strategy - if final_result and ctx.deps.end_strategy == 'early': - for call in tool_calls_by_kind['function']: - parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - async for event in process_function_tools( - toolset, - tool_calls_by_kind['function'], - ctx, - parts, - ): - yield event - - if unknown_calls: - ctx.state.increment_retries(ctx.deps.max_result_retries) - async for event in process_function_tools( - toolset, - unknown_calls, - ctx, - parts, - ): - yield event - - deferred_calls: list[_messages.ToolCallPart] = [] - for call in tool_calls_by_kind['deferred']: - if final_result: - parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - yield _messages.FunctionToolCallEvent(call) - deferred_calls.append(call) - - if deferred_calls: - if not ctx.deps.output_schema.deferred_tool_calls: - raise exceptions.UserError( - 'There are pending tool calls but DeferredToolCalls is not among output types.' - ) - - deferred_tool_names = [call.tool_name for call in deferred_calls] - deferred_tool_defs = { - tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in deferred_tool_names - } - output_data = cast(NodeRunEndT, DeferredToolCalls(deferred_calls, deferred_tool_defs)) - final_result = result.FinalResult(output_data) + async for event in process_function_tools(ctx.deps.toolset, tool_calls, None, ctx, parts, final_result_holder): + yield event - if final_result: + if final_result_holder: + final_result = final_result_holder[0] self._next_node = self._handle_final_result(ctx, final_result, parts) else: instructions = await ctx.deps.get_instructions(run_context) @@ -652,24 +564,85 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str: async def process_function_tools( toolset: AbstractToolset[DepsT], tool_calls: list[_messages.ToolCallPart], + final_result: result.FinalResult[NodeRunEndT] | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], - output_parts: list[_messages.ModelRequestPart], + parts: list[_messages.ModelRequestPart], + final_result_holder: list[result.FinalResult[NodeRunEndT]] = [], ) -> AsyncIterator[_messages.HandleResponseEvent]: """Process function (i.e., non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. - Because async iterators can't have return values, we use `output_parts` as an output argument. + Because async iterators can't have return values, we use `parts` as an output argument. """ run_context = build_run_context(ctx) - calls_to_run: list[_messages.ToolCallPart] = [] - call_index_to_event_id: dict[int, str] = {} + unknown_calls: list[_messages.ToolCallPart] = [] + tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list) + # TODO: Make Toolset.tool_defs a dict + tool_defs_by_name: dict[str, ToolDefinition] = {tool_def.name: tool_def for tool_def in toolset.tool_defs} for call in tool_calls: - event = _messages.FunctionToolCallEvent(call) - yield event - call_index_to_event_id[len(calls_to_run)] = event.call_id - calls_to_run.append(call) + try: + tool_def = tool_defs_by_name[call.tool_name] + tool_calls_by_kind[tool_def.kind].append(call) + except KeyError: + unknown_calls.append(call) + + # first, look for the output tool call + for call in tool_calls_by_kind['output']: + if final_result: + if final_result.tool_call_id == call.tool_call_id: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Final result processed.', + tool_call_id=call.tool_call_id, + ) + else: + yield _messages.FunctionToolCallEvent(call) + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Output tool not used - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id) + + parts.append(part) + else: + try: + result_data = await _call_tool(toolset, call, run_context) + except _output.ToolRetryError as e: + yield _messages.FunctionToolCallEvent(call) + parts.append(e.tool_retry) + yield _messages.FunctionToolResultEvent(e.tool_retry, tool_call_id=call.tool_call_id) + else: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Final result processed.', + tool_call_id=call.tool_call_id, + ) + parts.append(part) + final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) + + calls_to_run: list[_messages.ToolCallPart] = [] + # Then build the other request parts based on end strategy + if final_result and ctx.deps.end_strategy == 'early': + for call in tool_calls_by_kind['function']: + parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + calls_to_run.extend(tool_calls_by_kind['function']) + + if unknown_calls: + ctx.state.increment_retries(ctx.deps.max_result_retries) + calls_to_run.extend(unknown_calls) + + for call in calls_to_run: + yield _messages.FunctionToolCallEvent(call) user_parts: list[_messages.UserPromptPart] = [] @@ -698,12 +671,12 @@ async def process_function_tools( done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: index = tasks.index(task) - result = task.result() - yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index]) + tool_result = task.result() + yield _messages.FunctionToolResultEvent(tool_result, tool_call_id=tool_result.tool_call_id) - if isinstance(result, _messages.RetryPromptPart): - results_by_index[index] = result - elif isinstance(result, _messages.ToolReturnPart): + if isinstance(tool_result, _messages.RetryPromptPart): + results_by_index[index] = tool_result + elif isinstance(tool_result, _messages.ToolReturnPart): def process_content(content: Any) -> Any: if isinstance(content, _messages.MultiModalContentTypes): @@ -715,7 +688,7 @@ def process_content(content: Any) -> Any: user_parts.append( _messages.UserPromptPart( content=[f'This is file {identifier}:', content], - timestamp=result.timestamp, + timestamp=tool_result.timestamp, part_kind='user-prompt', ) ) @@ -723,22 +696,50 @@ def process_content(content: Any) -> Any: else: return content - if isinstance(result.content, list): - contents = cast(list[Any], result.content) # type: ignore - result.content = [process_content(content) for content in contents] + if isinstance(tool_result.content, list): + contents = cast(list[Any], tool_result.content) # type: ignore + tool_result.content = [process_content(content) for content in contents] else: - result.content = process_content(result.content) + tool_result.content = process_content(tool_result.content) - results_by_index[index] = result + results_by_index[index] = tool_result else: - assert_never(result) + assert_never(tool_result) # We append the results at the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing for k in sorted(results_by_index): - output_parts.append(results_by_index[k]) + parts.append(results_by_index[k]) + + deferred_calls: list[_messages.ToolCallPart] = [] + for call in tool_calls_by_kind['deferred']: + if final_result: + parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + yield _messages.FunctionToolCallEvent(call) + deferred_calls.append(call) + + if deferred_calls: + if not ctx.deps.output_schema.deferred_tool_calls: + raise exceptions.UserError('There are pending tool calls but DeferredToolCalls is not among output types.') + + deferred_tool_names = [call.tool_name for call in deferred_calls] + deferred_tool_defs = { + tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in deferred_tool_names + } + output_data = cast(NodeRunEndT, DeferredToolCalls(deferred_calls, deferred_tool_defs)) + final_result = result.FinalResult(output_data) + + parts.extend(user_parts) - output_parts.extend(user_parts) + if final_result: + final_result_holder.append(final_result) async def _call_function_tool( diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 2c5d37ca5..6c5643e7a 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1046,10 +1046,11 @@ async def stream_to_final( ): # pragma: no branch for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) + # TODO: Handle DeferredToolCalls return None - final_result_details = await stream_to_final(streamed_response) - if final_result_details is not None: + final_result = await stream_to_final(streamed_response) + if final_result is not None: if yielded: raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover yielded = True @@ -1068,19 +1069,16 @@ async def on_complete() -> None: part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) ] + # TODO: Should we move on to the CallToolsNode here, instead of doing this ourselves? parts: list[_messages.ModelRequestPart] = [] - # TODO: Make this work again. We may have pulled too much out of process_function_tools :) async for _event in _agent_graph.process_function_tools( graph_ctx.deps.toolset, tool_calls, + final_result, graph_ctx, parts, ): pass - # TODO: Should we do something here related to the retry count? - # Maybe we should move the incrementing of the retry count to where we actually make a request? - # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - # ctx.state.increment_retries(ctx.deps.max_result_retries) if parts: messages.append(_messages.ModelRequest(parts)) @@ -1092,10 +1090,11 @@ async def on_complete() -> None: graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), graph_ctx.deps.output_validators, - final_result_details.tool_name, + final_result.tool_name, on_complete, ) break + # TODO: There may be deferred tool calls, process those. next_node = await agent_run.next(node) if not isinstance(next_node, _agent_graph.AgentNode): raise exceptions.AgentRunError( # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index d7d6a51cb..64da8b475 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -113,6 +113,8 @@ async def _validate_response( 'Invalid response, unable to process text output' ) + # TODO: Possibly return DeferredToolCalls here? + for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) return result_data @@ -380,6 +382,7 @@ async def get_output(self) -> OutputDataT: pass message = self._stream_response.get() await self._marked_completed(message) + # TODO: Possibly return DeferredToolCalls here? return await self.validate_structured_output(message) @deprecated('`get_data` is deprecated, use `get_output` instead.') From ad6e8262a3a17e908f6980898c6b947a9a1c81a8 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 27 Jun 2025 00:31:04 +0000 Subject: [PATCH 60/90] Fix streaming tool calls --- docs/output.md | 1 - pydantic_ai_slim/pydantic_ai/_agent_graph.py | 119 ++++++++++--------- tests/models/test_anthropic.py | 2 +- tests/test_logfire.py | 1 + tests/test_streaming.py | 59 ++++----- 5 files changed, 93 insertions(+), 89 deletions(-) diff --git a/docs/output.md b/docs/output.md index 08c5a00ff..f32e403d7 100644 --- a/docs/output.md +++ b/docs/output.md @@ -139,7 +139,6 @@ from pydantic import BaseModel from pydantic_ai import Agent, ModelRetry, RunContext from pydantic_ai.exceptions import UnexpectedModelBehavior -from pydantic_ai.output import ToolRetryError class Row(BaseModel): diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 84fac934e..2588721ec 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -561,7 +561,7 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str: return hashlib.sha1(identifier).hexdigest()[:6] -async def process_function_tools( +async def process_function_tools( # noqa: C901 toolset: AbstractToolset[DepsT], tool_calls: list[_messages.ToolCallPart], final_result: result.FinalResult[NodeRunEndT] | None, @@ -646,70 +646,72 @@ async def process_function_tools( user_parts: list[_messages.UserPromptPart] = [] - include_content = ( - ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content - ) + if calls_to_run: + include_content = ( + ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content + ) - # Run all tool tasks in parallel - results_by_index: dict[int, _messages.ModelRequestPart] = {} - with ctx.deps.tracer.start_as_current_span( - 'running tools', - attributes={ - 'tools': [call.tool_name for call in calls_to_run], - 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', - }, - ): - tasks = [ - asyncio.create_task( - _call_function_tool(toolset, call, run_context, ctx.deps.tracer, include_content), name=call.tool_name - ) - for call in calls_to_run - ] - - pending = tasks - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - index = tasks.index(task) - tool_result = task.result() - yield _messages.FunctionToolResultEvent(tool_result, tool_call_id=tool_result.tool_call_id) - - if isinstance(tool_result, _messages.RetryPromptPart): - results_by_index[index] = tool_result - elif isinstance(tool_result, _messages.ToolReturnPart): - - def process_content(content: Any) -> Any: - if isinstance(content, _messages.MultiModalContentTypes): - if isinstance(content, _messages.BinaryContent): - identifier = multi_modal_content_identifier(content.data) + # Run all tool tasks in parallel + results_by_index: dict[int, _messages.ModelRequestPart] = {} + with ctx.deps.tracer.start_as_current_span( + 'running tools', + attributes={ + 'tools': [call.tool_name for call in calls_to_run], + 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', + }, + ): + tasks = [ + asyncio.create_task( + _call_function_tool(toolset, call, run_context, ctx.deps.tracer, include_content), + name=call.tool_name, + ) + for call in calls_to_run + ] + + pending = tasks + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.index(task) + tool_result = task.result() + yield _messages.FunctionToolResultEvent(tool_result, tool_call_id=tool_result.tool_call_id) + + if isinstance(tool_result, _messages.RetryPromptPart): + results_by_index[index] = tool_result + elif isinstance(tool_result, _messages.ToolReturnPart): + + def process_content(content: Any) -> Any: + if isinstance(content, _messages.MultiModalContentTypes): + if isinstance(content, _messages.BinaryContent): + identifier = multi_modal_content_identifier(content.data) + else: + identifier = multi_modal_content_identifier(content.url) + + user_parts.append( + _messages.UserPromptPart( + content=[f'This is file {identifier}:', content], + timestamp=tool_result.timestamp, + part_kind='user-prompt', + ) + ) + return f'See file {identifier}' else: - identifier = multi_modal_content_identifier(content.url) + return content - user_parts.append( - _messages.UserPromptPart( - content=[f'This is file {identifier}:', content], - timestamp=tool_result.timestamp, - part_kind='user-prompt', - ) - ) - return f'See file {identifier}' + if isinstance(tool_result.content, list): + contents = cast(list[Any], tool_result.content) # type: ignore + tool_result.content = [process_content(content) for content in contents] else: - return content + tool_result.content = process_content(tool_result.content) - if isinstance(tool_result.content, list): - contents = cast(list[Any], tool_result.content) # type: ignore - tool_result.content = [process_content(content) for content in contents] + results_by_index[index] = tool_result else: - tool_result.content = process_content(tool_result.content) - - results_by_index[index] = tool_result - else: - assert_never(tool_result) + assert_never(tool_result) - # We append the results at the end, rather than as they are received, to retain a consistent ordering - # This is mostly just to simplify testing - for k in sorted(results_by_index): - parts.append(results_by_index[k]) + # We append the results at the end, rather than as they are received, to retain a consistent ordering + # This is mostly just to simplify testing + for k in sorted(results_by_index): + parts.append(results_by_index[k]) deferred_calls: list[_messages.ToolCallPart] = [] for call in tool_calls_by_kind['deferred']: @@ -739,6 +741,7 @@ def process_content(content: Any) -> Any: parts.extend(user_parts) if final_result: + # TODO: Use some better "box" object final_result_holder.append(final_result) diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3891c5108..77857e882 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1700,7 +1700,7 @@ class CityLocation(BaseModel): agent = Agent(m, output_type=NativeOutput(CityLocation)) - with pytest.raises(UserError, match='Structured output is not supported by the model.'): + with pytest.raises(UserError, match='Native structured output is not supported by the model.'): await agent.run('What is the largest city in the user country?') diff --git a/tests/test_logfire.py b/tests/test_logfire.py index b0beef859..3ea0229a3 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -289,6 +289,7 @@ async def my_ret(x: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ], 'output_mode': 'text', diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 1d8a0d9b0..ff3fab39a 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -613,18 +613,18 @@ def another_tool(y: int) -> int: timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), - RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), ToolReturnPart( tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), ] ), ] @@ -712,15 +712,15 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', @@ -733,10 +733,7 @@ def another_tool(y: int) -> int: # pragma: no cover part_kind='tool-return', ), RetryPromptPart( - content='Unknown tool name: ' - "'unknown_tool'. Available tools: " - 'regular_tool, another_tool, ' - 'final_result', + content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), @@ -975,6 +972,13 @@ def known_tool(x: int) -> int: assert event_parts == snapshot( [ + FunctionToolCallEvent( + part=ToolCallPart( + tool_name='known_tool', + args={'x': 5}, + tool_call_id=IsStr(), + ) + ), FunctionToolCallEvent( part=ToolCallPart( tool_name='unknown_tool', @@ -991,9 +995,6 @@ def known_tool(x: int) -> int: ), tool_call_id=IsStr(), ), - FunctionToolCallEvent( - part=ToolCallPart(tool_name='known_tool', args={'x': 5}, tool_call_id=IsStr()), - ), FunctionToolResultEvent( result=ToolReturnPart( tool_name='known_tool', @@ -1003,13 +1004,6 @@ def known_tool(x: int) -> int: ), tool_call_id=IsStr(), ), - FunctionToolCallEvent( - part=ToolCallPart( - tool_name='unknown_tool', - args={'arg': 'value'}, - tool_call_id=IsStr(), - ), - ), ] ) @@ -1029,15 +1023,15 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf agent = Agent(FunctionModel(call_final_result_with_bad_data), output_type=OutputType) - event_parts: list[Any] = [] + events: list[Any] = [] async with agent.iter('test') as agent_run: async for node in agent_run: if Agent.is_call_tools_node(node): async with node.stream(agent_run.ctx) as event_stream: async for event in event_stream: - event_parts.append(event) + events.append(event) - assert event_parts == snapshot( + assert events == snapshot( [ FunctionToolCallEvent( part=ToolCallPart( @@ -1047,9 +1041,16 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf ), ), FunctionToolResultEvent( - result=ToolReturnPart( + result=RetryPromptPart( + content=[ + { + 'type': 'missing', + 'loc': ('value',), + 'msg': 'Field required', + 'input': {'bad_value': 'invalid'}, + } + ], tool_name='final_result', - content='Output tool not used - result failed validation.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), From 84cd9547a14ee9616f588dff45b8a5aa1301f3c2 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 27 Jun 2025 04:16:10 +0000 Subject: [PATCH 61/90] Stop double counting retries and reset on success --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 27 +--- pydantic_ai_slim/pydantic_ai/agent.py | 1 + pydantic_ai_slim/pydantic_ai/mcp.py | 4 +- pydantic_ai_slim/pydantic_ai/toolset.py | 141 +++++++++++-------- tests/models/test_model_test.py | 5 +- tests/test_examples.py | 4 +- 6 files changed, 98 insertions(+), 84 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 2588721ec..d48bdd3fe 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -12,7 +12,6 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast from opentelemetry.trace import Tracer -from pydantic import ValidationError from typing_extensions import TypeGuard, TypeVar, assert_never from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore @@ -610,7 +609,11 @@ async def process_function_tools( # noqa: C901 else: try: result_data = await _call_tool(toolset, call, run_context) + except exceptions.UnexpectedModelBehavior as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + raise e except _output.ToolRetryError as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) yield _messages.FunctionToolCallEvent(call) parts.append(e.tool_retry) yield _messages.FunctionToolResultEvent(e.tool_retry, tool_call_id=call.tool_call_id) @@ -792,26 +795,8 @@ async def _call_tool( toolset: AbstractToolset[DepsT], tool_call: _messages.ToolCallPart, run_context: RunContext[DepsT] ) -> Any: run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id) - - try: - args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args) - response_content = await toolset.call_tool(run_context, tool_call.tool_name, args_dict) - except (ValidationError, exceptions.ModelRetry) as e: - if isinstance(e, ValidationError): - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=tool_call.tool_call_id, - ) - else: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=e.message, - tool_call_id=tool_call.tool_call_id, - ) - raise _output.ToolRetryError(m) - - return response_content + args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args) + return await toolset.call_tool(run_context, tool_call.tool_name, args_dict) async def _validate_output( diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 6c5643e7a..f89053980 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -364,6 +364,7 @@ def __init__( self._function_toolset = FunctionToolset[AgentDepsT](tools, max_retries=retries) # This will raise errors for any name conflicts + # TODO: Also include toolsets (not mcp_serves as we won't have tool defs yet) CombinedToolset[AgentDepsT]([self._output_toolset, self._function_toolset]) # TODO: Set max_retries on MCPServer diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 33c055563..8f980c246 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -182,14 +182,14 @@ async def list_tool_defs(self) -> list[ToolDefinition]: for mcp_tool in mcp_tools ] - def get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_core.SchemaValidator: + def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_core.SchemaValidator: return pydantic_core.SchemaValidator( schema=pydantic_core.core_schema.dict_schema( pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema() ) ) - def max_retries_for_tool(self, name: str) -> int: + def _max_retries_for_tool(self, name: str) -> int: return 1 def set_mcp_sampling_model(self, model: models.Model) -> None: diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 7fa274fec..464eef427 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -2,19 +2,20 @@ import asyncio from abc import ABC, abstractmethod -from collections.abc import Awaitable, Sequence -from contextlib import AsyncExitStack +from collections.abc import Awaitable, Iterator, Sequence +from contextlib import AsyncExitStack, contextmanager from dataclasses import dataclass, field, replace from functools import partial from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, assert_never, overload from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema from pydantic_core import SchemaValidator -from typing_extensions import Never, Self +from typing_extensions import Self -from ._output import BaseOutputSchema, OutputValidator +from . import messages as _messages +from ._output import BaseOutputSchema, OutputValidator, ToolRetryError from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError from .tools import ( @@ -70,21 +71,21 @@ def tool_names(self) -> list[str]: return [tool_def.name for tool_def in self.tool_defs] @abstractmethod - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: raise NotImplementedError() def validate_tool_args( self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False ) -> dict[str, Any]: pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - validator = self.get_tool_args_validator(ctx, name) + validator = self._get_tool_args_validator(ctx, name) if isinstance(args, str): return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial) else: return validator.validate_python(args or {}, allow_partial=pyd_allow_partial) @abstractmethod - def max_retries_for_tool(self, name: str) -> int: + def _max_retries_for_tool(self, name: str) -> int: raise NotImplementedError() @abstractmethod @@ -273,10 +274,10 @@ async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDef def tool_defs(self) -> list[ToolDefinition]: return [tool.tool_def for tool in self.tools.values()] - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: return self.tools[name].function_schema.validator - def max_retries_for_tool(self, name: str) -> int: + def _max_retries_for_tool(self, name: str) -> int: tool = self.tools[name] return tool.max_retries if tool.max_retries is not None else self.max_retries @@ -298,10 +299,10 @@ class OutputToolset(AbstractToolset[AgentDepsT]): def tool_defs(self) -> list[ToolDefinition]: return [tool.tool_def for tool in self.output_schema.tools.values()] - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: return self.output_schema.tools[name].processor.validator - def max_retries_for_tool(self, name: str) -> int: + def _max_retries_for_tool(self, name: str) -> int: return self.max_retries async def call_tool( @@ -365,16 +366,16 @@ def tool_defs(self) -> list[ToolDefinition]: def tool_names(self) -> list[str]: return list(self._toolset_per_tool_name.keys()) - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self._toolset_for_tool_name(name).get_tool_args_validator(ctx, name) + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self._toolset_for_tool_name(name)._get_tool_args_validator(ctx, name) def validate_tool_args( self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False ) -> dict[str, Any]: return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial) - def max_retries_for_tool(self, name: str) -> int: - return self._toolset_for_tool_name(name).max_retries_for_tool(name) + def _max_retries_for_tool(self, name: str) -> int: + return self._toolset_for_tool_name(name)._max_retries_for_tool(name) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any @@ -419,11 +420,11 @@ async def __aexit__( def tool_defs(self) -> list[ToolDefinition]: return self.wrapped.tool_defs - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self.wrapped.get_tool_args_validator(ctx, name) + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.wrapped._get_tool_args_validator(ctx, name) - def max_retries_for_tool(self, name: str) -> int: - return self.wrapped.max_retries_for_tool(name) + def _max_retries_for_tool(self, name: str) -> int: + return self.wrapped._max_retries_for_tool(name) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any @@ -452,11 +453,11 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent def tool_defs(self) -> list[ToolDefinition]: return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs] - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return super().get_tool_args_validator(ctx, self._unprefixed_tool_name(name)) + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super()._get_tool_args_validator(ctx, self._unprefixed_tool_name(name)) - def max_retries_for_tool(self, name: str) -> int: - return super().max_retries_for_tool(self._unprefixed_tool_name(name)) + def _max_retries_for_tool(self, name: str) -> int: + return super()._max_retries_for_tool(self._unprefixed_tool_name(name)) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any @@ -519,11 +520,11 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent def tool_defs(self) -> list[ToolDefinition]: return self._tool_defs - def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return super().get_tool_args_validator(ctx, self._map_name(name)) + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super()._get_tool_args_validator(ctx, self._map_name(name)) - def max_retries_for_tool(self, name: str) -> int: - return super().max_retries_for_tool(self._map_name(name)) + def _max_retries_for_tool(self, name: str) -> int: + return super()._max_retries_for_tool(self._map_name(name)) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any @@ -660,40 +661,66 @@ def tool_names(self) -> list[str]: def validate_tool_args( self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False ) -> dict[str, Any]: - try: - self._validate_tool_name(name) - - ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) + with self._with_retry(name, ctx) as ctx: return super().validate_tool_args(ctx, name, args, allow_partial) - except ValidationError as e: - return self._on_error(name, e) async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any ) -> Any: + with self._with_retry(name, ctx) as ctx: + try: + output = await super().call_tool(ctx, name, tool_args, *args, **kwargs) + except Exception as e: + raise e + else: + self._retries.pop(name, None) + return output + + @contextmanager + def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]: try: - self._validate_tool_name(name) - - ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0)) - return await super().call_tool(ctx, name, tool_args, *args, **kwargs) - except ModelRetry as e: - return self._on_error(name, e) - - def _on_error(self, name: str, e: Exception) -> Never: - max_retries = self.max_retries_for_tool(name) - current_retry = self._retries.get(name, 0) - if current_retry == max_retries: - raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e - else: - self._retries[name] = current_retry + 1 # TODO: Reset on successful call! - raise e + if name not in self.tool_names: + if self.tool_names: + msg = f'Available tools: {", ".join(self.tool_names)}' + else: + msg = 'No tools available.' + raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') + + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) + yield ctx + except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: + if isinstance(e, ToolRetryError): + pass + elif isinstance(e, ValidationError): + if ctx.tool_call_id: + m = _messages.RetryPromptPart( + tool_name=name, + content=e.errors(include_url=False, include_context=False), + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + elif isinstance(e, ModelRetry): + if ctx.tool_call_id: + m = _messages.RetryPromptPart( + tool_name=name, + content=e.message, + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + elif isinstance(e, UnexpectedModelBehavior): + if e.__cause__ is not None: + e = e.__cause__ + else: + assert_never(e) - def _validate_tool_name(self, name: str) -> None: - if name in self.tool_names: - return + try: + max_retries = self._max_retries_for_tool(name) + except Exception: + max_retries = 1 + current_retry = self._retries.get(name, 0) - if self.tool_names: - msg = f'Available tools: {", ".join(self.tool_names)}' - else: - msg = 'No tools available.' - raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') + if current_retry == max_retries: + raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e + else: + self._retries[name] = current_retry + 1 + raise e diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 9d022641f..c43c55b89 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -4,6 +4,7 @@ import asyncio import dataclasses +import re from datetime import timezone from typing import Annotated, Any, Literal @@ -157,7 +158,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel: call_count += 1 raise ModelRetry('Fail') - with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"): + with pytest.raises(UnexpectedModelBehavior, match=re.escape('Exceeded maximum retries (2) for result validation')): agent.run_sync('Hello', model=TestModel()) assert call_count == 3 @@ -200,7 +201,7 @@ class ResultModel(BaseModel): agent = Agent('test', output_type=ResultModel, retries=2) - with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for result validation'): agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1})) diff --git a/tests/test_examples.py b/tests/test_examples.py index 0a223bb73..71c6cea51 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -270,10 +270,10 @@ async def __aexit__(self, *args: Any) -> None: def tool_defs(self) -> list[ToolDefinition]: return [] - def get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator: + def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator: return SchemaValidator(core_schema.any_schema()) # pragma: lax no cover - def max_retries_for_tool(self, name: str) -> int: + def _max_retries_for_tool(self, name: str) -> int: return 0 # pragma: lax no cover async def call_tool( From 74a56ae9b9f94b8ad35f9b1d9868bf891c8632b3 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 27 Jun 2025 13:50:26 +0000 Subject: [PATCH 62/90] Fix retry error wrapping --- docs/output.md | 4 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 4 +- pydantic_ai_slim/pydantic_ai/toolset.py | 45 +++++++++----------- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/docs/output.md b/docs/output.md index f32e403d7..825f2ec23 100644 --- a/docs/output.md +++ b/docs/output.md @@ -200,8 +200,8 @@ async def hand_off_to_sql_agent(ctx: RunContext, query: str) -> list[Row]: return output except UnexpectedModelBehavior as e: # Bubble up potentially retryable errors to the router agent - if (cause := e.__cause__) and hasattr(cause, 'tool_retry'): - raise ModelRetry(f'SQL agent failed: {cause.tool_retry.content}') from e + if (cause := e.__cause__) and isinstance(cause, ModelRetry): + raise ModelRetry(f'SQL agent failed: {cause.message}') from e else: raise diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index d48bdd3fe..a77c537e9 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -79,11 +79,13 @@ class GraphAgentState: retries: int run_step: int - def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None: + def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None: self.retries += 1 if self.retries > max_result_retries: message = f'Exceeded maximum retries ({max_result_retries}) for result validation' if error: + if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None: + error = error.__cause__ raise exceptions.UnexpectedModelBehavior(message) from error else: raise exceptions.UnexpectedModelBehavior(message) diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 464eef427..1748e355d 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -7,7 +7,7 @@ from dataclasses import dataclass, field, replace from functools import partial from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, assert_never, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, overload from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema @@ -689,38 +689,33 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) yield ctx except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: - if isinstance(e, ToolRetryError): - pass - elif isinstance(e, ValidationError): - if ctx.tool_call_id: - m = _messages.RetryPromptPart( - tool_name=name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=ctx.tool_call_id, - ) - e = ToolRetryError(m) - elif isinstance(e, ModelRetry): - if ctx.tool_call_id: - m = _messages.RetryPromptPart( - tool_name=name, - content=e.message, - tool_call_id=ctx.tool_call_id, - ) - e = ToolRetryError(m) - elif isinstance(e, UnexpectedModelBehavior): - if e.__cause__ is not None: - e = e.__cause__ - else: - assert_never(e) - try: max_retries = self._max_retries_for_tool(name) except Exception: max_retries = 1 current_retry = self._retries.get(name, 0) + if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None: + e = e.__cause__ + if current_retry == max_retries: raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e else: + if ctx.tool_call_id: + if isinstance(e, ValidationError): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.errors(include_url=False, include_context=False), + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + elif isinstance(e, ModelRetry): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.message, + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + self._retries[name] = current_retry + 1 raise e From 0360e7791308c01617a58c91781a976fe7db979d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 30 Jun 2025 15:42:28 +0000 Subject: [PATCH 63/90] Make DeferredToolCalls work with streaming --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 40 +++----- pydantic_ai_slim/pydantic_ai/_output.py | 12 +-- pydantic_ai_slim/pydantic_ai/agent.py | 18 ++-- pydantic_ai_slim/pydantic_ai/result.py | 35 +++++-- pydantic_ai_slim/pydantic_ai/toolset.py | 26 ++++- tests/test_streaming.py | 100 ++++++++++++++++++- 6 files changed, 178 insertions(+), 53 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index a77c537e9..e6cb27c26 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -21,7 +21,7 @@ from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage -from .output import DeferredToolCalls, OutputDataT, OutputSpec +from .output import OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings from .tools import RunContext, ToolDefinition, ToolKind @@ -310,6 +310,7 @@ async def stream( ctx.deps.output_validators, build_run_context(ctx), ctx.deps.usage_limits, + ctx.deps.toolset, ) yield agent_stream # In case the user didn't manually consume the full stream, ensure it is fully consumed here, @@ -497,6 +498,13 @@ async def _handle_tool_calls( if final_result_holder: final_result = final_result_holder[0] self._next_node = self._handle_final_result(ctx, final_result, parts) + elif deferred_tool_calls := ctx.deps.toolset.get_deferred_tool_calls(tool_calls): + if not ctx.deps.output_schema.deferred_tool_calls: + raise exceptions.UserError( + 'There are deferred tool calls but DeferredToolCalls is not among output types.' + ) + final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None) + self._next_node = self._handle_final_result(ctx, final_result, parts) else: instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( @@ -578,16 +586,11 @@ async def process_function_tools( # noqa: C901 """ run_context = build_run_context(ctx) - unknown_calls: list[_messages.ToolCallPart] = [] - tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list) - # TODO: Make Toolset.tool_defs a dict - tool_defs_by_name: dict[str, ToolDefinition] = {tool_def.name: tool_def for tool_def in toolset.tool_defs} + tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list) for call in tool_calls: - try: - tool_def = tool_defs_by_name[call.tool_name] - tool_calls_by_kind[tool_def.kind].append(call) - except KeyError: - unknown_calls.append(call) + tool_def = toolset.get_tool_def(call.tool_name) + kind = tool_def.kind if tool_def else 'unknown' + tool_calls_by_kind[kind].append(call) # first, look for the output tool call for call in tool_calls_by_kind['output']: @@ -642,9 +645,9 @@ async def process_function_tools( # noqa: C901 else: calls_to_run.extend(tool_calls_by_kind['function']) - if unknown_calls: + if tool_calls_by_kind['unknown']: ctx.state.increment_retries(ctx.deps.max_result_retries) - calls_to_run.extend(unknown_calls) + calls_to_run.extend(tool_calls_by_kind['unknown']) for call in calls_to_run: yield _messages.FunctionToolCallEvent(call) @@ -718,7 +721,6 @@ def process_content(content: Any) -> Any: for k in sorted(results_by_index): parts.append(results_by_index[k]) - deferred_calls: list[_messages.ToolCallPart] = [] for call in tool_calls_by_kind['deferred']: if final_result: parts.append( @@ -730,18 +732,6 @@ def process_content(content: Any) -> Any: ) else: yield _messages.FunctionToolCallEvent(call) - deferred_calls.append(call) - - if deferred_calls: - if not ctx.deps.output_schema.deferred_tool_calls: - raise exceptions.UserError('There are pending tool calls but DeferredToolCalls is not among output types.') - - deferred_tool_names = [call.tool_name for call in deferred_calls] - deferred_tool_defs = { - tool_def.name: tool_def for tool_def in toolset.tool_defs if tool_def.name in deferred_tool_names - } - output_data = cast(NodeRunEndT, DeferredToolCalls(deferred_calls, deferred_tool_defs)) - final_result = result.FinalResult(output_data) parts.extend(user_parts) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index d53c7e801..fd9bc9c19 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -3,7 +3,7 @@ import inspect import json from abc import ABC, abstractmethod -from collections.abc import Awaitable, Iterable, Iterator, Sequence +from collections.abc import Awaitable, Iterable, Sequence from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload @@ -549,16 +549,6 @@ def find_named_tool( if part.tool_name == tool_name: return part, self.tools[tool_name] - def find_tool( - self, - parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: - """Find a tool that matches one of the calls.""" - for part in parts: - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if result := self.tools.get(part.tool_name): - yield part, result - @dataclass(init=False) class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchema[OutputDataT]): diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index f89053980..2b77b184e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1042,12 +1042,13 @@ async def stream_to_final( output_schema, _output.TextOutputSchema ): return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart) and isinstance( - output_schema, _output.ToolOutputSchema - ): # pragma: no branch - for call, _ in output_schema.find_tool([new_part]): - return FinalResult(s, call.tool_name, call.tool_call_id) - # TODO: Handle DeferredToolCalls + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := graph_ctx.deps.toolset.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return FinalResult(s, new_part.tool_name, new_part.tool_call_id) + elif tool_def.kind == 'deferred': + return FinalResult(s, None, None) return None final_result = await stream_to_final(streamed_response) @@ -1072,16 +1073,20 @@ async def on_complete() -> None: # TODO: Should we move on to the CallToolsNode here, instead of doing this ourselves? parts: list[_messages.ModelRequestPart] = [] + # final_result_holder: list[result.FinalResult[models.StreamedResponse]] = [] async for _event in _agent_graph.process_function_tools( graph_ctx.deps.toolset, tool_calls, final_result, graph_ctx, parts, + # final_result_holder, ): pass if parts: messages.append(_messages.ModelRequest(parts)) + # if final_result_holder: + # final_result = final_result_holder[0] yield StreamedRunResult( messages, @@ -1093,6 +1098,7 @@ async def on_complete() -> None: graph_ctx.deps.output_validators, final_result.tool_name, on_complete, + graph_ctx.deps.toolset, ) break # TODO: There may be deferred tool calls, process those. diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 64da8b475..366392d34 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,11 +5,13 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import Generic +from typing import Generic, cast from pydantic import ValidationError from typing_extensions import TypeVar, deprecated, overload +from pydantic_ai.toolset import RunToolset + from . import _utils, exceptions, messages as _messages, models from ._output import ( OutputDataT_inv, @@ -47,6 +49,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None + _toolset: RunToolset[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) _final_result_event: FinalResultEvent | None = field(default=None, init=False) @@ -102,6 +105,12 @@ async def _validate_response( result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.deferred_tool_calls: + raise exceptions.UserError( + 'There are deferred tool calls but DeferredToolCalls is not among output types.' + ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) @@ -113,8 +122,6 @@ async def _validate_response( 'Invalid response, unable to process text output' ) - # TODO: Possibly return DeferredToolCalls here? - for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) return result_data @@ -136,13 +143,19 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" if isinstance(e, _messages.PartStartEvent): new_part = e.part - if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema): - for call, _ in output_schema.find_tool([new_part]): # pragma: no branch - return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id) - elif isinstance(new_part, _messages.TextPart) and isinstance( + if isinstance(new_part, _messages.TextPart) and isinstance( output_schema, TextOutputSchema ): # pragma: no branch return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := self._toolset.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return _messages.FinalResultEvent( + tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id + ) + elif tool_def.kind == 'deferred': + return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) usage_checking_stream = _get_usage_checking_stream_response( self._raw_stream_response, self._usage_limits, self.usage @@ -177,6 +190,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] + _toolset: RunToolset[AgentDepsT] _initial_run_ctx_usage: Usage = field(init=False) is_complete: bool = field(default=False, init=False) @@ -382,7 +396,6 @@ async def get_output(self) -> OutputDataT: pass message = self._stream_response.get() await self._marked_completed(message) - # TODO: Possibly return DeferredToolCalls here? return await self.validate_structured_output(message) @deprecated('`get_data` is deprecated, use `get_output` instead.') @@ -423,6 +436,12 @@ async def validate_structured_output( result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.deferred_tool_calls: + raise exceptions.UserError( + 'There are deferred tool calls but DeferredToolCalls is not among output types.' + ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 1748e355d..879eefd7a 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -2,7 +2,7 @@ import asyncio from abc import ABC, abstractmethod -from collections.abc import Awaitable, Iterator, Sequence +from collections.abc import Awaitable, Iterable, Iterator, Sequence from contextlib import AsyncExitStack, contextmanager from dataclasses import dataclass, field, replace from functools import partial @@ -14,6 +14,8 @@ from pydantic_core import SchemaValidator from typing_extensions import Self +from pydantic_ai.output import DeferredToolCalls + from . import messages as _messages from ._output import BaseOutputSchema, OutputValidator, ToolRetryError from ._run_context import AgentDepsT, RunContext @@ -70,6 +72,9 @@ def tool_defs(self) -> list[ToolDefinition]: def tool_names(self) -> list[str]: return [tool_def.name for tool_def in self.tool_defs] + def get_tool_def(self, name: str) -> ToolDefinition | None: + return next((tool_def for tool_def in self.tool_defs if tool_def.name == name), None) + @abstractmethod def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: raise NotImplementedError() @@ -676,6 +681,25 @@ async def call_tool( self._retries.pop(name, None) return output + def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None: + deferred_calls_and_defs = [ + (part, tool_def) + for part in parts + if isinstance(part, _messages.ToolCallPart) + and (tool_def := self.get_tool_def(part.tool_name)) + and tool_def.kind == 'deferred' + ] + if not deferred_calls_and_defs: + return None + + deferred_calls: list[_messages.ToolCallPart] = [] + deferred_tool_defs: dict[str, ToolDefinition] = {} + for part, tool_def in deferred_calls_and_defs: + deferred_calls.append(part) + deferred_tool_defs[part.tool_name] = tool_def + + return DeferredToolCalls(deferred_calls, deferred_tool_defs) + @contextmanager def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]: try: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index ff3fab39a..5d9a757b8 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,6 +5,7 @@ import re from collections.abc import AsyncIterator from copy import deepcopy +from dataclasses import replace from datetime import timezone from typing import Any, Union @@ -12,14 +13,16 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai import Agent, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( + FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, ModelMessage, ModelRequest, ModelResponse, + PartStartEvent, RetryPromptPart, TextPart, ToolCallPart, @@ -28,8 +31,9 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import PromptedOutput, TextOutput +from pydantic_ai.output import DeferredToolCalls, PromptedOutput, TextOutput from pydantic_ai.result import AgentStream, FinalResult, Usage +from pydantic_ai.tools import ToolDefinition from pydantic_graph import End from .conftest import IsInt, IsNow, IsStr @@ -1081,3 +1085,95 @@ class CityLocation(BaseModel): ] ) assert result.is_complete + + +async def test_deferred_tool(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 + + async with agent.run_stream('Hello') as result: + assert not result.is_complete + output = await result.get_output() + assert output == snapshot( + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ) + assert result.is_complete + + +async def test_deferred_tool_iter(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 + + outputs: list[str | DeferredToolCalls] = [] + events: list[Any] = [] + + async with agent.iter('test') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + events.append(event) + async for output in stream.stream_output(debounce_by=None): + outputs.append(output) + if agent.is_call_tools_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + events.append(event) + + assert outputs == snapshot( + [ + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ] + ) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr()), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + FunctionToolCallEvent(part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())), + ] + ) From 8a3febb9ff7acdc040a1058bdfef17a9b0564fae Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 30 Jun 2025 17:01:53 +0000 Subject: [PATCH 64/90] Let toolsets be overridden in run/iter/run_stream/run_sync --- pydantic_ai_slim/pydantic_ai/agent.py | 57 ++++++++++++++++++++------- tests/test_agent.py | 40 +++++++++++++++++++ 2 files changed, 83 insertions(+), 14 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b96e8c333..bb5f342cb 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -155,7 +155,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]): ) _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) _output_toolset: OutputToolset[AgentDepsT] = dataclasses.field(repr=False) + _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) + _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) _toolset: AbstractToolset[AgentDepsT] = dataclasses.field(repr=False) + _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False) _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False) @@ -179,7 +182,7 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), - toolsets: Sequence[AbstractToolset[AgentDepsT]] = (), + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -210,7 +213,7 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), - toolsets: Sequence[AbstractToolset[AgentDepsT]] = (), + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -238,7 +241,7 @@ def __init__( mcp_servers: Sequence[ MCPServer ] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets - toolsets: Sequence[AbstractToolset[AgentDepsT]] = (), + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -361,19 +364,18 @@ def __init__( self._system_prompt_dynamic_functions = {} self._max_result_retries = output_retries if output_retries is not None else retries + self._prepare_tools = prepare_tools - self._output_toolset = OutputToolset[AgentDepsT](self._output_schema, max_retries=self._max_result_retries) - self._function_toolset = FunctionToolset[AgentDepsT](tools, max_retries=retries) + self._output_toolset = OutputToolset(self._output_schema, max_retries=self._max_result_retries) + self._function_toolset = FunctionToolset(tools, max_retries=retries) + self._user_toolsets = toolsets or () + # TODO: Set max_retries on MCPServer + self._mcp_servers = mcp_servers # This will raise errors for any name conflicts - # TODO: Also include toolsets (not mcp_serves as we won't have tool defs yet) - CombinedToolset[AgentDepsT]([self._output_toolset, self._function_toolset]) - - # TODO: Set max_retries on MCPServer - toolset = CombinedToolset[AgentDepsT]([self._function_toolset, *toolsets, *mcp_servers]) - if prepare_tools: - toolset = PreparedToolset[AgentDepsT](toolset, prepare_tools) - self._toolset = toolset + self._toolset = CombinedToolset( + [self._output_toolset, self._function_toolset, *self._user_toolsets, *self._mcp_servers] + ) self.history_processors = history_processors or [] @@ -395,6 +397,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -410,6 +413,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -426,6 +430,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -440,6 +445,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -470,6 +476,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. Returns: The result of the run. @@ -494,6 +501,7 @@ async def main(): model_settings=model_settings, usage_limits=usage_limits, usage=usage, + toolsets=toolsets, ) as agent_run: async for _ in agent_run: pass @@ -514,6 +522,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -530,6 +539,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -547,6 +557,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager @@ -562,6 +573,7 @@ async def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -636,6 +648,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. Returns: The result of the run. @@ -693,7 +706,11 @@ async def main(): run_step=state.run_step, ) - toolset = CombinedToolset([output_toolset, self._toolset]) + user_toolsets = self._user_toolsets if toolsets is None else toolsets + toolset = CombinedToolset([self._function_toolset, *user_toolsets, *self._mcp_servers]) + if self._prepare_tools: + toolset = PreparedToolset(toolset, self._prepare_tools) + toolset = CombinedToolset([output_toolset, toolset]) run_toolset = await toolset.prepare_for_run(run_context) model_settings = merge_model_settings(self.model_settings, model_settings) @@ -814,6 +831,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -829,6 +847,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -845,6 +864,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -859,6 +879,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -888,6 +909,7 @@ def run_sync( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. Returns: The result of the run. @@ -914,6 +936,7 @@ def run_sync( usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, ) ) @@ -929,6 +952,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -944,6 +968,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @overload @@ -960,6 +985,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -975,6 +1001,7 @@ async def run_stream( # noqa C901 usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -1002,6 +1029,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. Returns: The result of the run. @@ -1032,6 +1060,7 @@ async def main(): usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node diff --git a/tests/test_agent.py b/tests/test_agent.py index 517449d0f..97654872d 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -45,6 +45,7 @@ from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolset import FunctionToolset from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -3451,3 +3452,42 @@ def test_deprecated_kwargs_mixed_valid_invalid(): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) # Ignore the deprecation warning for result_tool_name Agent('test', result_tool_name='test', foo='value1', bar='value2') # type: ignore[call-arg] + + +def test_override_toolsets(): + foo_toolset = FunctionToolset() + + @foo_toolset.tool + def foo() -> str: + return 'Hello from foo' + + bar_toolset = FunctionToolset() + + @bar_toolset.tool + def bar() -> str: + return 'Hello from bar' + + available_tools: list[list[str]] = [] + + async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + nonlocal available_tools + available_tools.append([tool_def.name for tool_def in tool_defs]) + return tool_defs + + agent = Agent('test', toolsets=[foo_toolset], prepare_tools=prepare_tools) + + @agent.tool_plain + def baz() -> str: + return 'Hello from baz' + + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz', 'foo']) + assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo"}') + + result = agent.run_sync('Hello', toolsets=[bar_toolset]) + assert available_tools[-1] == snapshot(['baz', 'bar']) + assert result.output == snapshot('{"baz":"Hello from baz","bar":"Hello from bar"}') + + result = agent.run_sync('Hello', toolsets=[]) + assert available_tools[-1] == snapshot(['baz']) + assert result.output == snapshot('{"baz":"Hello from baz"}') From 2e200acb57e6e13b533cce52386f62c95c2fc2ae Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 30 Jun 2025 17:53:48 +0000 Subject: [PATCH 65/90] Add DeferredToolset --- pydantic_ai_slim/pydantic_ai/toolset.py | 24 ++++++++++++++++++++++++ tests/test_tools.py | 22 ++++++++++++---------- 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index 879eefd7a..b87bc5b50 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -319,6 +319,30 @@ async def call_tool( return output +class DeferredToolset(AbstractToolset[AgentDepsT]): + """A toolset that holds deferred tool.""" + + _tool_defs: list[ToolDefinition] + + def __init__(self, tool_defs: list[ToolDefinition]): + self._tool_defs = tool_defs + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [replace(tool_def, kind='deferred') for tool_def in self._tool_defs] + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + raise NotImplementedError('Deferred tools cannot be validated') + + def _max_retries_for_tool(self, name: str) -> int: + raise NotImplementedError('Deferred tools cannot be retried') + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + raise NotImplementedError('Deferred tools cannot be called') + + @dataclass(init=False) class CombinedToolset(AbstractToolset[AgentDepsT]): """A toolset that combines multiple toolsets.""" diff --git a/tests/test_tools.py b/tests/test_tools.py index 9a4a6bf7f..09a7001c7 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -18,6 +18,7 @@ from pydantic_ai.models.test import TestModel from pydantic_ai.output import DeferredToolCalls, ToolOutput from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolset import DeferredToolset from .conftest import IsStr @@ -1181,14 +1182,16 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int: def test_deferred_tool(): - agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) - - async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: - return replace(tool_def, kind='deferred') - - @agent.tool_plain(prepare=prepare_tool) - def my_tool(x: int) -> int: - return x + 1 + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls], toolsets=[deferred_toolset]) result = agent.run_sync('Hello') assert result.output == snapshot( @@ -1199,10 +1202,9 @@ def my_tool(x: int) -> int: name='my_tool', description='', parameters_json_schema={ - 'additionalProperties': False, + 'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x'], - 'type': 'object', }, kind='deferred', ) From 1cb7f324b499240315b70de2de61849af4f00a28 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 30 Jun 2025 18:02:09 +0000 Subject: [PATCH 66/90] Add LangChainToolset --- pydantic_ai_slim/pydantic_ai/ext/langchain.py | 10 ++++- pydantic_ai_slim/pydantic_ai/toolset.py | 6 +-- tests/ext/test_langchain.py | 45 +++++++++++-------- tests/test_tools.py | 2 +- 4 files changed, 38 insertions(+), 25 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 9d13adda0..1512ee4f7 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -3,6 +3,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai.tools import Tool +from pydantic_ai.toolset import FunctionToolset class LangChainTool(Protocol): @@ -23,7 +24,7 @@ def description(self) -> str: ... def run(self, *args: Any, **kwargs: Any) -> str: ... -__all__ = ('tool_from_langchain',) +__all__ = ('tool_from_langchain', 'LangChainToolset') def tool_from_langchain(langchain_tool: LangChainTool) -> Tool: @@ -59,3 +60,10 @@ def proxy(*args: Any, **kwargs: Any) -> str: description=function_description, json_schema=schema, ) + + +class LangChainToolset(FunctionToolset): + """A toolset that wraps LangChain tools.""" + + def __init__(self, tools: list[LangChainTool]): + super().__init__([tool_from_langchain(tool) for tool in tools]) diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index b87bc5b50..fbbcbe618 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -46,7 +46,7 @@ class AbstractToolset(ABC, Generic[AgentDepsT]): @property def name(self) -> str: - return self.__class__.__name__ + return self.__class__.__name__.replace('Toolset', ' toolset') @property def name_conflict_hint(self) -> str: @@ -110,10 +110,6 @@ class FunctionToolset(AbstractToolset[AgentDepsT]): max_retries: int = field(default=1) tools: dict[str, Tool[Any]] = field(default_factory=dict) - @property - def name(self) -> str: - return 'FunctionToolset' - def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): self.max_retries = max_retries self.tools = {} diff --git a/tests/ext/test_langchain.py b/tests/ext/test_langchain.py index 73e7cc050..926a22819 100644 --- a/tests/ext/test_langchain.py +++ b/tests/ext/test_langchain.py @@ -6,7 +6,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai import Agent -from pydantic_ai.ext.langchain import tool_from_langchain +from pydantic_ai.ext.langchain import LangChainToolset, tool_from_langchain @dataclass @@ -49,24 +49,26 @@ def get_input_jsonschema(self) -> JsonSchemaValue: } -def test_langchain_tool_conversion(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, +langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', }, - ) + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, +) + + +def test_langchain_tool_conversion(): pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) @@ -74,6 +76,13 @@ def test_langchain_tool_conversion(): assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") +def test_langchain_toolset(): + toolset = LangChainToolset([langchain_tool]) + agent = Agent('test', toolsets=[toolset], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") + + def test_langchain_tool_no_additional_properties(): langchain_tool = SimulatedLangChainTool( name='file_search', diff --git a/tests/test_tools.py b/tests/test_tools.py index 09a7001c7..721e3941a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -589,7 +589,7 @@ def test_tool_return_conflict(): # this raises an error with pytest.raises( UserError, - match="FunctionToolset defines a tool whose name conflicts with existing tool from OutputToolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.", + match="Function toolset defines a tool whose name conflicts with existing tool from OutputToolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.", ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) From a6eba43a3a22546223713ef7fb28cf944ee75800 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 30 Jun 2025 20:02:28 +0000 Subject: [PATCH 67/90] Add Agent.prepare_output_tools --- pydantic_ai_slim/pydantic_ai/agent.py | 12 +++- tests/test_agent.py | 92 +++++++++++++++++++++++++++ tests/test_tools.py | 2 +- 3 files changed, 104 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index bb5f342cb..52475ebcc 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -159,6 +159,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) _toolset: AbstractToolset[AgentDepsT] = dataclasses.field(repr=False) _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) + _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False) _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False) @@ -181,6 +182,7 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, @@ -212,6 +214,7 @@ def __init__( result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, @@ -238,6 +241,7 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[ MCPServer ] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets @@ -270,9 +274,12 @@ def __init__( output_retries: The maximum number of retries to allow for result validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. - prepare_tools: custom method to prepare the tool definition of all tools for each step. + prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. + This is useful if you want to customize the definition of multiple output tools or you want to register + a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] for each server you want the agent to connect to. toolsets: Toolsets to register with the agent. @@ -365,6 +372,7 @@ def __init__( self._max_result_retries = output_retries if output_retries is not None else retries self._prepare_tools = prepare_tools + self._prepare_output_tools = prepare_output_tools self._output_toolset = OutputToolset(self._output_schema, max_retries=self._max_result_retries) self._function_toolset = FunctionToolset(tools, max_retries=retries) @@ -682,6 +690,8 @@ async def main(): output_toolset = OutputToolset[AgentDepsT]( output_schema, max_retries=self._max_result_retries, output_validators=output_validators ) + if self._prepare_output_tools: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( diff --git a/tests/test_agent.py b/tests/test_agent.py index 97654872d..c58fc3948 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,6 +1,7 @@ import json import re import sys +from dataclasses import dataclass from datetime import timezone from typing import Any, Callable, Union @@ -3491,3 +3492,94 @@ def baz() -> str: result = agent.run_sync('Hello', toolsets=[]) assert available_tools[-1] == snapshot(['baz']) assert result.output == snapshot('{"baz":"Hello from baz"}') + + +def test_prepare_output_tools(): + @dataclass + class AgentDeps: + plan_presented: bool = False + + async def present_plan(ctx: RunContext[AgentDeps], plan: str) -> str: + """ + Present the plan to the user. + """ + ctx.deps.plan_presented = True + return plan + + async def run_sql(ctx: RunContext[AgentDeps], purpose: str, query: str) -> str: + """ + Run an SQL query. + """ + return 'SQL query executed successfully' + + async def only_if_plan_presented( + ctx: RunContext[AgentDeps], tool_defs: list[ToolDefinition] + ) -> list[ToolDefinition]: + return tool_defs if ctx.deps.plan_presented else [] + + agent = Agent( + model='test', + deps_type=AgentDeps, + tools=[present_plan], + output_type=[ToolOutput(run_sql, name='run_sql')], + prepare_output_tools=only_if_plan_presented, + ) + + result = agent.run_sync('Hello', deps=AgentDeps()) + assert result.output == snapshot('SQL query executed successfully') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='present_plan', + args={'plan': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='present_plan', + content='a', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='run_sql', + args={'purpose': 'a', 'query': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=52, response_tokens=12, total_tokens=64), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='run_sql', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) diff --git a/tests/test_tools.py b/tests/test_tools.py index 721e3941a..3d6ab8d6a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -589,7 +589,7 @@ def test_tool_return_conflict(): # this raises an error with pytest.raises( UserError, - match="Function toolset defines a tool whose name conflicts with existing tool from OutputToolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.", + match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.", ): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) From 0c9612611380dd6684f32e45509f64771086752f Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 1 Jul 2025 18:08:21 +0000 Subject: [PATCH 68/90] Require WrapperToolset subclasses to implement their own prepare_for_run --- pydantic_ai_slim/pydantic_ai/toolset.py | 4 ++++ tests/test_mcp.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py index fbbcbe618..62efd7f7d 100644 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ b/pydantic_ai_slim/pydantic_ai/toolset.py @@ -441,6 +441,10 @@ async def __aexit__( ) -> bool | None: return await self.wrapped.__aexit__(exc_type, exc_value, traceback) + @abstractmethod + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + raise NotImplementedError() + @property def tool_defs(self) -> list[ToolDefinition]: return self.wrapped.tool_defs diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 83c4d33b7..41230f0aa 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -232,7 +232,7 @@ def get_none() -> None: # pragma: no cover with pytest.raises( UserError, match=re.escape( - "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from FunctionToolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from Function toolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." ), ): await agent.run('Get me a conflict') From 2348f45eaf4f71a806db782559d48e44c46d2e8b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 1 Jul 2025 18:50:17 +0000 Subject: [PATCH 69/90] Require DeferredToolCalls to be used with other output type --- pydantic_ai_slim/pydantic_ai/_output.py | 6 +++++ tests/test_tools.py | 29 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 1409cd3f1..4debde975 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -190,6 +190,12 @@ def build( # noqa: C901 outputs = [output for output in raw_outputs if output is not DeferredToolCalls] deferred_tool_calls = len(outputs) < len(raw_outputs) + if len(outputs) == 0: + if deferred_tool_calls: + raise UserError('At least one output type must be provided other than DeferredToolCalls.') + else: + raise UserError('At least one output type must be provided.') + if output := next((output for output in outputs if isinstance(output, NativeOutput)), None): if len(outputs) > 1: raise UserError('NativeOutput cannot be mixed with other output types.') diff --git a/tests/test_tools.py b/tests/test_tools.py index 3d6ab8d6a..8ce580bf3 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1211,3 +1211,32 @@ def test_deferred_tool(): }, ) ) + + +def test_deferred_tool_with_output_type(): + class MyModel(BaseModel): + foo: str + + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(call_tools=[]), output_type=[MyModel, DeferredToolCalls], toolsets=[deferred_toolset]) + + result = agent.run_sync('Hello') + assert result.output == snapshot(MyModel(foo='a')) + + +def test_output_type_deferred_tool_calls_by_itself(): + with pytest.raises(UserError, match='At least one output type must be provided other than DeferredToolCalls.'): + Agent(TestModel(), output_type=DeferredToolCalls) + + +def test_output_type_empty(): + with pytest.raises(UserError, match='At least one output type must be provided.'): + Agent(TestModel(), output_type=[]) From f3124c080925f14959ae19f06b5af8e6fbc9c19d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 1 Jul 2025 23:53:41 +0000 Subject: [PATCH 70/90] Lots of cleanup --- docs/mcp/client.md | 14 +- mcp-run-python/README.md | 2 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 205 ++--- pydantic_ai_slim/pydantic_ai/_output.py | 259 +++--- pydantic_ai_slim/pydantic_ai/agent.py | 102 ++- pydantic_ai_slim/pydantic_ai/exceptions.py | 12 + pydantic_ai_slim/pydantic_ai/ext/langchain.py | 2 +- pydantic_ai_slim/pydantic_ai/mcp.py | 25 +- pydantic_ai_slim/pydantic_ai/output.py | 2 +- pydantic_ai_slim/pydantic_ai/result.py | 66 +- pydantic_ai_slim/pydantic_ai/tools.py | 13 +- pydantic_ai_slim/pydantic_ai/toolset.py | 769 ------------------ .../pydantic_ai/toolsets/__init__.py | 84 ++ .../pydantic_ai/toolsets/combined.py | 99 +++ .../pydantic_ai/toolsets/deferred.py | 38 + .../pydantic_ai/toolsets/filtered.py | 24 + .../pydantic_ai/toolsets/function.py | 208 +++++ .../toolsets/individually_prepared.py | 44 + .../pydantic_ai/toolsets/mapped.py | 55 ++ .../pydantic_ai/toolsets/prefixed.py | 47 ++ .../pydantic_ai/toolsets/prepared.py | 29 + .../pydantic_ai/toolsets/processed.py | 44 + pydantic_ai_slim/pydantic_ai/toolsets/run.py | 142 ++++ .../pydantic_ai/toolsets/wrapper.py | 66 ++ tests/models/test_model_test.py | 4 +- tests/test_agent.py | 6 +- tests/test_examples.py | 6 +- tests/test_mcp.py | 10 +- tests/test_streaming.py | 4 +- tests/test_tools.py | 4 +- tests/test_toolset.py | 2 +- 31 files changed, 1281 insertions(+), 1106 deletions(-) delete mode 100644 pydantic_ai_slim/pydantic_ai/toolset.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/__init__.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/combined.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/deferred.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/filtered.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/function.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/mapped.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/prepared.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/processed.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/run.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 336e47045..17de670d6 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -47,7 +47,7 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE(url='http://localhost:3001/sse') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): @@ -118,7 +118,7 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): async with agent.run_toolsets(): # (3)! @@ -156,7 +156,7 @@ server = MCPServerStdio( # (1)! 'stdio', ] ) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): @@ -200,7 +200,7 @@ server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_t agent = Agent( model=TestModel(call_tools=['echo_deps']), deps_type=int, - mcp_servers=[server] + toolsets=[server] ) @@ -243,7 +243,7 @@ calculator_server = MCPServerSSE( # Both servers might have a tool named 'get_data', but they'll be exposed as: # - 'weather_get_data' # - 'calc_get_data' -agent = Agent('openai:gpt-4o', mcp_servers=[weather_server, calculator_server]) +agent = Agent('openai:gpt-4o', toolsets=[weather_server, calculator_server]) ``` ### Example with Stdio Server @@ -273,7 +273,7 @@ js_server = MCPServerStdio( tool_prefix='js' # Tools will be prefixed with 'js_' ) -agent = Agent('openai:gpt-4o', mcp_servers=[python_server, js_server]) +agent = Agent('openai:gpt-4o', toolsets=[python_server, js_server]) ``` When the model interacts with these servers, it will see the prefixed tool names, but the prefixes will be automatically handled when making tool calls. @@ -360,7 +360,7 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio(command='python', args=['generate_svg.py']) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index 0d57fb762..93bbfc87f 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -52,7 +52,7 @@ server = MCPServerStdio('deno', 'jsr:@pydantic/mcp-run-python', 'stdio', ]) -agent = Agent('claude-3-5-haiku-latest', mcp_servers=[server]) +agent = Agent('claude-3-5-haiku-latest', toolsets=[server]) async def main(): diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index cdc8b4dab..0439a1f2e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -4,7 +4,7 @@ import dataclasses import hashlib import json -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar @@ -16,11 +16,13 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._utils import is_async_callable, run_in_executor -from pydantic_ai.toolset import AbstractToolset, RunToolset +from pydantic_ai.toolsets import AbstractToolset +from pydantic_ai.toolsets.run import RunToolset from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage +from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings from .tools import RunContext, ToolDefinition, ToolKind @@ -82,7 +84,7 @@ class GraphAgentState: def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None: self.retries += 1 if self.retries > max_result_retries: - message = f'Exceeded maximum retries ({max_result_retries}) for result validation' + message = f'Exceeded maximum retries ({max_result_retries}) for output validation' if error: if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None: error = error.__cause__ @@ -489,26 +491,28 @@ async def _handle_tool_calls( ) -> AsyncIterator[_messages.HandleResponseEvent]: run_context = build_run_context(ctx) - parts: list[_messages.ModelRequestPart] = [] - final_result_holder: list[result.FinalResult[NodeRunEndT]] = [] + output_parts: list[_messages.ModelRequestPart] = [] + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) - async for event in process_function_tools(ctx.deps.toolset, tool_calls, None, ctx, parts, final_result_holder): + async for event in process_function_tools( + ctx.deps.toolset, tool_calls, None, ctx, output_parts, output_final_result + ): yield event - if final_result_holder: - final_result = final_result_holder[0] - self._next_node = self._handle_final_result(ctx, final_result, parts) + if output_final_result: + final_result = output_final_result[0] + self._next_node = self._handle_final_result(ctx, final_result, output_parts) elif deferred_tool_calls := ctx.deps.toolset.get_deferred_tool_calls(tool_calls): - if not ctx.deps.output_schema.deferred_tool_calls: + if not ctx.deps.output_schema.allows_deferred_tool_calls: raise exceptions.UserError( 'There are deferred tool calls but DeferredToolCalls is not among output types.' ) final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None) - self._next_node = self._handle_final_result(ctx, final_result, parts) + self._next_node = self._handle_final_result(ctx, final_result, output_parts) else: instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest(parts=parts, instructions=instructions) + _messages.ModelRequest(parts=output_parts, instructions=instructions) ) def _handle_final_result( @@ -541,10 +545,10 @@ async def _handle_text_response( m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', ) - raise _output.ToolRetryError(m) + raise ToolRetryError(m) result_data = await _validate_output(result_data, ctx, None) - except _output.ToolRetryError as e: + except ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: @@ -575,8 +579,8 @@ async def process_function_tools( # noqa: C901 tool_calls: list[_messages.ToolCallPart], final_result: result.FinalResult[NodeRunEndT] | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], - parts: list[_messages.ModelRequestPart], - final_result_holder: list[result.FinalResult[NodeRunEndT]] = [], + output_parts: list[_messages.ModelRequestPart], + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1), ) -> AsyncIterator[_messages.HandleResponseEvent]: """Process function (i.e., non-result) tool calls in parallel. @@ -592,7 +596,7 @@ async def process_function_tools( # noqa: C901 kind = tool_def.kind if tool_def else 'unknown' tool_calls_by_kind[kind].append(call) - # first, look for the output tool call + # First, we handle output tool calls for call in tool_calls_by_kind['output']: if final_result: if final_result.tool_call_id == call.tool_call_id: @@ -610,17 +614,17 @@ async def process_function_tools( # noqa: C901 ) yield _messages.FunctionToolResultEvent(part) - parts.append(part) + output_parts.append(part) else: try: result_data = await _call_tool(toolset, call, run_context) except exceptions.UnexpectedModelBehavior as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) raise e - except _output.ToolRetryError as e: + except ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) yield _messages.FunctionToolCallEvent(call) - parts.append(e.tool_retry) + output_parts.append(e.tool_retry) yield _messages.FunctionToolResultEvent(e.tool_retry) else: part = _messages.ToolReturnPart( @@ -628,14 +632,14 @@ async def process_function_tools( # noqa: C901 content='Final result processed.', tool_call_id=call.tool_call_id, ) - parts.append(part) + output_parts.append(part) final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) + # Then, we handle function tool calls calls_to_run: list[_messages.ToolCallPart] = [] - # Then build the other request parts based on end strategy if final_result and ctx.deps.end_strategy == 'early': for call in tool_calls_by_kind['function']: - parts.append( + output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Tool not executed - a final result was already processed.', @@ -645,6 +649,7 @@ async def process_function_tools( # noqa: C901 else: calls_to_run.extend(tool_calls_by_kind['function']) + # Then, we handle unknown tool calls if tool_calls_by_kind['unknown']: ctx.state.increment_retries(ctx.deps.max_result_retries) calls_to_run.extend(tool_calls_by_kind['unknown']) @@ -660,7 +665,7 @@ async def process_function_tools( # noqa: C901 ) # Run all tool tasks in parallel - results_by_index: dict[int, _messages.ModelRequestPart] = {} + parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} with ctx.deps.tracer.start_as_current_span( 'running tools', attributes={ @@ -681,78 +686,20 @@ async def process_function_tools( # noqa: C901 done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: index = tasks.index(task) - tool_result = task.result() - yield _messages.FunctionToolResultEvent(tool_result) - - if isinstance(tool_result, _messages.RetryPromptPart): - results_by_index[index] = tool_result - elif isinstance(tool_result, _messages.ToolReturnPart): - if isinstance(tool_result.content, _messages.ToolReturn): - tool_return = tool_result.content - if ( - isinstance(tool_return.return_value, _messages.MultiModalContentTypes) - or isinstance(tool_return.return_value, list) - and any( - isinstance(content, _messages.MultiModalContentTypes) - for content in tool_return.return_value # type: ignore - ) - ): - raise exceptions.UserError( - f"{tool_result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. " - f'Please use `content` instead.' - ) - tool_result.content = tool_return.return_value # type: ignore - tool_result.metadata = tool_return.metadata - if tool_return.content: - user_parts.append( - _messages.UserPromptPart( - content=list(tool_return.content), - timestamp=tool_result.timestamp, - part_kind='user-prompt', - ) - ) - - def process_content(content: Any) -> Any: - if isinstance(content, _messages.ToolReturn): - raise exceptions.UserError( - f"{tool_result.tool_name}'s return contains invalid nested ToolReturn objects. " - f'ToolReturn should be used directly.' - ) - elif isinstance(content, _messages.MultiModalContentTypes): - if isinstance(content, _messages.BinaryContent): - identifier = multi_modal_content_identifier(content.data) - else: - identifier = multi_modal_content_identifier(content.url) - - user_parts.append( - _messages.UserPromptPart( - content=[f'This is file {identifier}:', content], - timestamp=tool_result.timestamp, - part_kind='user-prompt', - ) - ) - return f'See file {identifier}' - else: - return content + tool_result_part, extra_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_result_part) - if isinstance(tool_result.content, list): - contents = cast(list[Any], tool_result.content) # type: ignore - tool_result.content = [process_content(content) for content in contents] - else: - tool_result.content = process_content(tool_result.content) - - results_by_index[index] = tool_result - else: - assert_never(tool_result) + parts_by_index[index] = [tool_result_part, *extra_parts] # We append the results at the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing - for k in sorted(results_by_index): - parts.append(results_by_index[k]) + for k in sorted(parts_by_index): + output_parts.extend(parts_by_index[k]) + # Finally, we handle deferred tool calls for call in tool_calls_by_kind['deferred']: if final_result: - parts.append( + output_parts.append( _messages.ToolReturnPart( tool_name=call.tool_name, content='Tool not executed - a final result was already processed.', @@ -762,11 +709,10 @@ def process_content(content: Any) -> Any: else: yield _messages.FunctionToolCallEvent(call) - parts.extend(user_parts) + output_parts.extend(user_parts) if final_result: - # TODO: Use some better "box" object - final_result_holder.append(final_result) + output_final_result.append(final_result) async def _call_function_tool( @@ -775,7 +721,7 @@ async def _call_function_tool( run_context: RunContext[DepsT], tracer: Tracer, include_content: bool = False, -) -> _messages.ToolReturnPart | _messages.RetryPromptPart: +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: """Run the tool function asynchronously. See . @@ -801,15 +747,72 @@ async def _call_function_tool( with tracer.start_as_current_span('running tool', attributes=span_attributes): try: - response_content = await _call_tool(toolset, tool_call, run_context) - except _output.ToolRetryError as e: - return e.tool_retry + tool_result = await _call_tool(toolset, tool_call, run_context) + except ToolRetryError as e: + return (e.tool_retry, []) + + extra_parts: list[_messages.ModelRequestPart] = [] + metadata = None + + def process_content(content: Any) -> Any: + if isinstance(content, _messages.ToolReturn): + raise exceptions.UserError( + f"{tool_call.tool_name}'s return contains invalid nested ToolReturn objects. " + f'ToolReturn should be used directly.' + ) + elif isinstance(content, _messages.MultiModalContentTypes): + if isinstance(content, _messages.BinaryContent): + identifier = multi_modal_content_identifier(content.data) + else: + identifier = multi_modal_content_identifier(content.url) + + extra_parts.append( + _messages.UserPromptPart( + content=[f'This is file {identifier}:', content], + part_kind='user-prompt', + ) + ) + return f'See file {identifier}' + else: + return content + + if isinstance(tool_result, _messages.ToolReturn): + if ( + isinstance(tool_result.return_value, _messages.MultiModalContentTypes) + or isinstance(tool_result.return_value, list) + and any( + isinstance(content, _messages.MultiModalContentTypes) + for content in tool_result.return_value # type: ignore + ) + ): + raise exceptions.UserError( + f"{tool_call.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. " + f'Please use `content` instead.' + ) + + metadata = tool_result.metadata + if tool_result.content: + extra_parts.append( + _messages.UserPromptPart( + content=list(tool_result.content), + part_kind='user-prompt', + ) + ) + tool_result = tool_result.return_value # type: ignore + elif isinstance(tool_result, list): + contents = cast(list[Any], tool_result) + tool_result = [process_content(content) for content in contents] else: - return _messages.ToolReturnPart( - tool_name=tool_call.tool_name, - content=response_content, - tool_call_id=tool_call.tool_call_id, - ) + tool_result = process_content(tool_result) + + part = _messages.ToolReturnPart( + tool_name=tool_call.tool_name, + content=tool_result, + metadata=metadata, + tool_call_id=tool_call.tool_call_id, + ) + + return (part, extra_parts) async def _call_tool( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 4debde975..6a22f76e8 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -3,7 +3,7 @@ import inspect import json from abc import ABC, abstractmethod -from collections.abc import Awaitable, Iterable, Sequence +from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload @@ -13,7 +13,7 @@ from . import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UserError +from .exceptions import ModelRetry, ToolRetryError, UserError from .output import ( DeferredToolCalls, NativeOutput, @@ -28,6 +28,8 @@ ToolOutput, ) from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition +from .toolsets import AbstractToolset +from .toolsets.run import RunToolset if TYPE_CHECKING: from .profiles import ModelProfile @@ -67,14 +69,6 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' -class ToolRetryError(Exception): - """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" - - def __init__(self, tool_retry: _messages.RetryPromptPart): - self.tool_retry = tool_retry - super().__init__() - - @dataclass class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]): function: OutputValidatorFunc[AgentDepsT, OutputDataT_inv] @@ -135,16 +129,16 @@ async def validate( @dataclass class BaseOutputSchema(ABC, Generic[OutputDataT]): - deferred_tool_calls: bool + allows_deferred_tool_calls: bool @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: raise NotImplementedError() @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return {} + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return None @dataclass(init=False) @@ -189,16 +183,13 @@ def build( # noqa: C901 raw_outputs = _flatten_output_spec(output_spec) outputs = [output for output in raw_outputs if output is not DeferredToolCalls] - deferred_tool_calls = len(outputs) < len(raw_outputs) - if len(outputs) == 0: - if deferred_tool_calls: - raise UserError('At least one output type must be provided other than DeferredToolCalls.') - else: - raise UserError('At least one output type must be provided.') + allows_deferred_tool_calls = len(outputs) < len(raw_outputs) + if len(outputs) == 0 and allows_deferred_tool_calls: + raise UserError('At least one output type must be provided other than `DeferredToolCalls`.') if output := next((output for output in outputs if isinstance(output, NativeOutput)), None): if len(outputs) > 1: - raise UserError('NativeOutput cannot be mixed with other output types.') + raise UserError('`NativeOutput` must be the only output type.') return NativeOutputSchema( processor=cls._build_processor( @@ -207,11 +198,11 @@ def build( # noqa: C901 description=output.description, strict=output.strict, ), - deferred_tool_calls=deferred_tool_calls, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None): if len(outputs) > 1: - raise UserError('PromptedOutput cannot be mixed with other output types.') + raise UserError('`PromptedOutput` must be the only output type.') return PromptedOutputSchema( processor=cls._build_processor( @@ -220,7 +211,7 @@ def build( # noqa: C901 description=output.description, ), template=output.template, - deferred_tool_calls=deferred_tool_calls, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] @@ -233,53 +224,62 @@ def build( # noqa: C901 text_outputs.append(output) elif isinstance(output, ToolOutput): tool_outputs.append(output) - elif isinstance(output, (NativeOutput, PromptedOutput)): - # We can never get here because these are checked for above. - raise UserError('NativeOutput and PromptedOutput must be the only output types.') # pragma: no cover + elif isinstance(output, NativeOutput): + # We can never get here because this is checked for above. + raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover + elif isinstance(output, PromptedOutput): + # We can never get here because this is checked for above. + raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover else: other_outputs.append(output) - tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) + toolset = cls._build_toolset(tool_outputs + other_outputs, name=name, description=description, strict=strict) if len(text_outputs) > 0: if len(text_outputs) > 1: - raise UserError('Only one text output is allowed.') + raise UserError('Only one `str` or `TextOutput` is allowed.') text_output = text_outputs[0] text_output_schema = None if isinstance(text_output, TextOutput): text_output_schema = PlainTextOutputProcessor(text_output.output_function) - if len(tools) == 0: - return PlainTextOutputSchema(processor=text_output_schema, deferred_tool_calls=deferred_tool_calls) - else: + if toolset: return ToolOrTextOutputSchema( - processor=text_output_schema, tools=tools, deferred_tool_calls=deferred_tool_calls + processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls + ) + else: + return PlainTextOutputSchema( + processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls ) if len(tool_outputs) > 0: - return ToolOutputSchema(tools=tools, deferred_tool_calls=deferred_tool_calls) + return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) if len(other_outputs) > 0: schema = OutputSchemaWithoutMode( processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), - tools=tools, - deferred_tool_calls=deferred_tool_calls, + toolset=toolset, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) if default_mode: schema = schema.with_default_mode(default_mode) return schema - raise UserError('No output type provided.') # pragma: no cover + raise UserError('At least one output type must be provided.') @staticmethod - def _build_tools( + def _build_toolset( outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> dict[str, OutputTool[OutputDataT]]: - tools: dict[str, OutputTool[OutputDataT]] = {} + ) -> OutputToolset[Any] | None: + if len(outputs) == 0: + return None + + processors: dict[str, ObjectOutputProcessor[Any]] = {} + tool_defs: list[ToolDefinition] = [] default_name = name or DEFAULT_OUTPUT_TOOL_NAME default_description = description @@ -305,7 +305,7 @@ def _build_tools( i = 1 original_name = name - while name in tools: + while name in processors: i += 1 name = f'{original_name}_{i}' @@ -314,9 +314,26 @@ def _build_tools( strict = default_strict processor = ObjectOutputProcessor(output=output, description=description, strict=strict) - tools[name] = OutputTool(name=name, processor=processor, multiple=multiple) + object_def = processor.object_def - return tools + description = object_def.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION + if multiple: + description = f'{object_def.name}: {description}' + + tool_def = ToolDefinition( + name=name, + description=description, + parameters_json_schema=object_def.json_schema, + strict=object_def.strict, + outer_typed_dict_key=processor.outer_typed_dict_key, + kind='output', + ) + processors[name] = processor + tool_defs.append(tool_def) + + return OutputToolset(processors=processors, tool_defs=tool_defs) @staticmethod def _build_processor( @@ -348,35 +365,39 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDa @dataclass(init=False) class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None = None def __init__( self, processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], - tools: dict[str, OutputTool[OutputDataT]], - deferred_tool_calls: bool, + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): - super().__init__(deferred_tool_calls) + super().__init__(allows_deferred_tool_calls) self.processor = processor - self._tools = tools + self._toolset = toolset def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'native': - return NativeOutputSchema(processor=self.processor, deferred_tool_calls=self.deferred_tool_calls) + return NativeOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'prompted': - return PromptedOutputSchema(processor=self.processor, deferred_tool_calls=self.deferred_tool_calls) + return PromptedOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'tool': - return ToolOutputSchema(tools=self.tools, deferred_tool_calls=self.deferred_tool_calls) + return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls) else: assert_never(mode) @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - # TODO: Update for toolsets - # We return tools here as they're checked in Agent._register_tool. - # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time. - return self._tools + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + # We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor. + # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time, + # but we cover ourselves just in case we end up using the tool output mode. + return self._toolset class TextOutputSchema(OutputSchema[OutputDataT], ABC): @@ -527,11 +548,11 @@ async def process( @dataclass(init=False) class ToolOutputSchema(OutputSchema[OutputDataT]): - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None = None - def __init__(self, tools: dict[str, OutputTool[OutputDataT]], deferred_tool_calls: bool): - super().__init__(deferred_tool_calls) - self._tools = tools + def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool): + super().__init__(allows_deferred_tool_calls) + self._toolset = toolset @property def mode(self) -> OutputMode: @@ -543,18 +564,9 @@ def raise_if_unsupported(self, profile: ModelProfile) -> None: raise UserError('Output tools are not supported by the model.') @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return self._tools - - def find_named_tool( - self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: - """Find a tool that matches one of the calls, with a specific name.""" - for part in parts: # pragma: no branch - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if part.tool_name == tool_name: - return part, self.tools[tool_name] + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return self._toolset @dataclass(init=False) @@ -562,10 +574,10 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem def __init__( self, processor: PlainTextOutputProcessor[OutputDataT] | None, - tools: dict[str, OutputTool[OutputDataT]], - deferred_tool_calls: bool, + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): - super().__init__(tools=tools, deferred_tool_calls=deferred_tool_calls) + super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) self.processor = processor @property @@ -888,73 +900,46 @@ async def process( @dataclass(init=False) -class OutputTool(Generic[OutputDataT]): - processor: ObjectOutputProcessor[OutputDataT] - tool_def: ToolDefinition +class OutputToolset(AbstractToolset[AgentDepsT]): + """A toolset that contains output tools.""" - def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool): - self.processor = processor - object_def = processor.object_def + _tool_defs: list[ToolDefinition] + processors: dict[str, ObjectOutputProcessor[Any]] + max_retries: int = field(default=1) + output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list) - description = object_def.description - if not description: - description = DEFAULT_OUTPUT_TOOL_DESCRIPTION - if multiple: - description = f'{object_def.name}: {description}' + def __init__( + self, + tool_defs: list[ToolDefinition], + processors: dict[str, ObjectOutputProcessor[Any]], + max_retries: int = 1, + output_validators: list[OutputValidator[AgentDepsT, Any]] = [], + ): + self.processors = processors + self._tool_defs = tool_defs + self.max_retries = max_retries + self.output_validators = output_validators - self.tool_def = ToolDefinition( - name=name, - description=description, - parameters_json_schema=object_def.json_schema, - strict=object_def.strict, - outer_typed_dict_key=processor.outer_typed_dict_key, - kind='output', - ) + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + return RunToolset(self, ctx) - async def process( - self, - tool_call: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - allow_partial: bool = False, - wrap_validation_errors: bool = True, - ) -> OutputDataT: - """Process an output message. + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs - Args: - tool_call: The tool call from the LLM to validate. - run_context: The current run context. - allow_partial: If true, allow partial validation. - wrap_validation_errors: If true, wrap the validation errors in a retry message. + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.processors[name].validator - Returns: - Either the validated output data (left) or a retry message (right). - """ - try: - output = await self.processor.process( - tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False - ) - except ValidationError as e: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from e - else: - raise # pragma: lax no cover - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=r.message, - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from r - else: - raise # pragma: lax no cover - else: - return output + def _max_retries_for_tool(self, name: str) -> int: + return self.max_retries + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + output = await self.processors[name].call(tool_args, ctx) + for validator in self.output_validators: + output = await validator.validate(output, None, ctx, wrap_validation_errors=False) + return output def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 52475ebcc..07891d9eb 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -14,8 +14,6 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated -from pydantic_ai.profiles import ModelProfile -from pydantic_ai.toolset import AbstractToolset, CombinedToolset, FunctionToolset, OutputToolset, PreparedToolset from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -31,8 +29,10 @@ usage as _usage, ) from ._agent_graph import HistoryProcessor +from ._output import OutputToolset from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from .output import OutputDataT, OutputSpec +from .profiles import ModelProfile from .result import FinalResult, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -48,6 +48,10 @@ ToolPrepareFunc, ToolsPrepareFunc, ) +from .toolsets import AbstractToolset +from .toolsets.combined import CombinedToolset +from .toolsets.function import FunctionToolset +from .toolsets.prepared import PreparedToolset # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -154,9 +158,8 @@ class Agent(Generic[AgentDepsT, OutputDataT]): repr=False ) _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) - _output_toolset: OutputToolset[AgentDepsT] = dataclasses.field(repr=False) + _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) - _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) _toolset: AbstractToolset[AgentDepsT] = dataclasses.field(repr=False) _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) @@ -183,7 +186,6 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', @@ -193,7 +195,7 @@ def __init__( @overload @deprecated( - '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' + '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' ) def __init__( self, @@ -215,7 +217,6 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', @@ -223,6 +224,35 @@ def __init__( history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, ) -> None: ... + @overload + @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + mcp_servers: Sequence[MCPServer] = (), + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + ) -> None: ... + def __init__( self, model: models.Model | models.KnownModelName | str | None = None, @@ -242,9 +272,6 @@ def __init__( tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[ - MCPServer - ] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', @@ -271,7 +298,7 @@ def __init__( when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow before raising an error. - output_retries: The maximum number of retries to allow for result validation, defaults to `retries`. + output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. @@ -280,9 +307,7 @@ def __init__( prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. This is useful if you want to customize the definition of multiple output tools or you want to register a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] - for each server you want the agent to connect to. - toolsets: Toolsets to register with the agent. + toolsets: Toolsets to register with the agent, including MCP servers. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` @@ -342,10 +367,18 @@ def __init__( warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning) output_retries = result_retries + if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): + warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) + if toolsets is None: + toolsets = mcp_servers + else: + toolsets = [*toolsets, *mcp_servers] + + _utils.validate_empty_kwargs(_deprecated_kwargs) + default_output_mode = ( self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) - _utils.validate_empty_kwargs(_deprecated_kwargs) self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, @@ -374,16 +407,21 @@ def __init__( self._prepare_tools = prepare_tools self._prepare_output_tools = prepare_output_tools - self._output_toolset = OutputToolset(self._output_schema, max_retries=self._max_result_retries) + self._output_toolset = self._output_schema.toolset + if self._output_toolset: + self._output_toolset.max_retries = self._max_result_retries + self._function_toolset = FunctionToolset(tools, max_retries=retries) self._user_toolsets = toolsets or () - # TODO: Set max_retries on MCPServer - self._mcp_servers = mcp_servers + + all_toolsets: list[AbstractToolset[AgentDepsT]] = [] + if self._output_toolset: + all_toolsets.append(self._output_toolset) + all_toolsets.append(self._function_toolset) + all_toolsets.extend(self._user_toolsets) # This will raise errors for any name conflicts - self._toolset = CombinedToolset( - [self._output_toolset, self._function_toolset, *self._user_toolsets, *self._mcp_servers] - ) + self._toolset = CombinedToolset(all_toolsets) self.history_processors = history_processors or [] @@ -687,11 +725,12 @@ async def main(): output_toolset = self._output_toolset if output_schema != self._output_schema or output_validators: - output_toolset = OutputToolset[AgentDepsT]( - output_schema, max_retries=self._max_result_retries, output_validators=output_validators - ) - if self._prepare_output_tools: - output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + output_toolset = output_schema.toolset + if output_toolset: + output_toolset.max_retries = self._max_result_retries + output_toolset.output_validators = output_validators + if self._prepare_output_tools: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( @@ -717,10 +756,11 @@ async def main(): ) user_toolsets = self._user_toolsets if toolsets is None else toolsets - toolset = CombinedToolset([self._function_toolset, *user_toolsets, *self._mcp_servers]) + toolset = CombinedToolset([self._function_toolset, *user_toolsets]) if self._prepare_tools: toolset = PreparedToolset(toolset, self._prepare_tools) - toolset = CombinedToolset([output_toolset, toolset]) + if output_toolset: + toolset = CombinedToolset([output_toolset, toolset]) run_toolset = await toolset.prepare_for_run(run_context) model_settings = merge_model_settings(self.model_settings, model_settings) @@ -1120,22 +1160,17 @@ async def on_complete() -> None: part for part in last_message.parts if isinstance(part, _messages.ToolCallPart) ] - # TODO: Should we move on to the CallToolsNode here, instead of doing this ourselves? parts: list[_messages.ModelRequestPart] = [] - # final_result_holder: list[result.FinalResult[models.StreamedResponse]] = [] async for _event in _agent_graph.process_function_tools( graph_ctx.deps.toolset, tool_calls, final_result, graph_ctx, parts, - # final_result_holder, ): pass if parts: messages.append(_messages.ModelRequest(parts)) - # if final_result_holder: - # final_result = final_result_holder[0] yield StreamedRunResult( messages, @@ -1150,7 +1185,6 @@ async def on_complete() -> None: graph_ctx.deps.toolset, ) break - # TODO: There may be deferred tool calls, process those. next_node = await agent_run.next(node) if not isinstance(next_node, _agent_graph.AgentNode): raise exceptions.AgentRunError( # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 078347825..01f599a9f 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -2,12 +2,16 @@ import json import sys +from typing import TYPE_CHECKING if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup # pragma: lax no cover else: ExceptionGroup = ExceptionGroup # pragma: lax no cover +if TYPE_CHECKING: + from .messages import RetryPromptPart + __all__ = ( 'ModelRetry', 'UserError', @@ -113,3 +117,11 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None class FallbackExceptionGroup(ExceptionGroup): """A group of exceptions that can be raised when all fallback models fail.""" + + +class ToolRetryError(Exception): + """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" + + def __init__(self, tool_retry: RetryPromptPart): + self.tool_retry = tool_retry + super().__init__() diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 1512ee4f7..83fc6b146 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -3,7 +3,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai.tools import Tool -from pydantic_ai.toolset import FunctionToolset +from pydantic_ai.toolsets.function import FunctionToolset class LangChainTool(Protocol): diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 8f980c246..ab07ed33e 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -20,7 +20,10 @@ from pydantic_ai.tools import ToolDefinition from .exceptions import UserError -from .toolset import AbstractToolset, PrefixedToolset, ProcessedToolset, RunToolset, ToolProcessFunc +from .toolsets import AbstractToolset +from .toolsets.prefixed import PrefixedToolset +from .toolsets.processed import ProcessedToolset, ToolProcessFunc +from .toolsets.run import RunToolset try: from mcp import types as mcp_types @@ -56,6 +59,7 @@ class MCPServer(AbstractToolset[Any], ABC): timeout: float = 5 process_tool_call: ToolProcessFunc[Any] | None = None allow_sampling: bool = True + max_retries: int = 1 # } end of "abstract fields" _running_count: int = 0 @@ -90,7 +94,7 @@ def name(self) -> str: return repr(self) @property - def name_conflict_hint(self) -> str: + def tool_name_conflict_hint(self) -> str: return 'Consider setting `tool_prefix` to avoid name conflicts.' async def list_tools(self) -> list[mcp_types.Tool]: @@ -169,6 +173,7 @@ async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: @property def tool_defs(self) -> list[ToolDefinition]: + # The actual tool definitions are loaded in `prepare_for_run` and cached on the `RunToolset` that will wrap us return [] async def list_tool_defs(self) -> list[ToolDefinition]: @@ -190,7 +195,7 @@ def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_ ) def _max_retries_for_tool(self, name: str) -> int: - return 1 + return self.max_retries def set_mcp_sampling_model(self, model: models.Model) -> None: self.sampling_model = model @@ -317,7 +322,7 @@ class MCPServerStdio(MCPServer): 'stdio', ] ) - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): async with agent.run_toolsets(): # (2)! @@ -373,6 +378,9 @@ async def main(): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + @asynccontextmanager async def client_streams( self, @@ -468,6 +476,9 @@ class _MCPServerHTTP(MCPServer): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + @property @abstractmethod def _transport_client( @@ -549,7 +560,7 @@ class MCPServerSSE(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): async with agent.run_toolsets(): # (2)! @@ -583,7 +594,7 @@ class MCPServerHTTP(MCPServerSSE): from pydantic_ai.mcp import MCPServerHTTP server = MCPServerHTTP('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): async with agent.run_toolsets(): # (2)! @@ -612,7 +623,7 @@ class MCPServerStreamableHTTP(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): async with agent.run_toolsets(): # (2)! diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 90f830cb8..316921781 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -295,7 +295,7 @@ def split_into_words(text: str) -> list[str]: @dataclass class DeferredToolCalls: - """Output type for calls to tools defined as pending.""" + """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.""" tool_calls: list[ToolCallPart] tool_defs: dict[str, ToolDefinition] diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 366392d34..2d1274318 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -3,14 +3,14 @@ import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from copy import copy -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import datetime from typing import Generic, cast from pydantic import ValidationError from typing_extensions import TypeVar, deprecated, overload -from pydantic_ai.toolset import RunToolset +from pydantic_ai.toolsets.run import RunToolset from . import _utils, exceptions, messages as _messages, models from ._output import ( @@ -93,20 +93,26 @@ async def _validate_response( self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - 'Invalid response, unable to find tool' + f'Invalid response, unable to find tool call for {output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id) + args_dict = self._toolset.validate_tool_args( + run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial ) + return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict) elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): - if not self._output_schema.deferred_tool_calls: + if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( 'There are deferred tool calls but DeferredToolCalls is not among output types.' ) @@ -117,15 +123,14 @@ async def _validate_response( result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, None, self._run_ctx) + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data - def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -424,20 +429,26 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - 'Invalid response, unable to find tool' + f'Invalid response, unable to find tool call for {self._output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id) + args_dict = self._toolset.validate_tool_args( + run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial ) + return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict) elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): - if not self._output_schema.deferred_tool_calls: + if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( 'There are deferred tool calls but DeferredToolCalls is not among output types.' ) @@ -448,15 +459,14 @@ async def validate_structured_output( result_data = await self._output_schema.process( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, None, self._run_ctx) # pragma: no cover + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover - return result_data - async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: text = await validator.validate(text, None, self._run_ctx) # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 201081c5f..782f11b09 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -309,11 +309,11 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition Returns: return a `ToolDefinition` or `None` if the tools should not be registered for this run. """ - standard_tool_def = self.tool_def + base_tool_def = self.tool_def if self.prepare is not None: - return await self.prepare(ctx, standard_tool_def) + return await self.prepare(ctx, base_tool_def) else: - return standard_tool_def + return base_tool_def ObjectJsonSchema: TypeAlias = dict[str, Any] @@ -363,6 +363,11 @@ class ToolDefinition: """ kind: ToolKind = field(default='function') - """The kind of tool.""" + """The kind of tool: + - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model + - `'output'`: a tool that passes through an output value that ends the run + - `'deferred'`: a tool that cannot be executed by Pydantic AI and needs to get a result from the outside. + When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s for each deferred call. + """ __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/toolset.py b/pydantic_ai_slim/pydantic_ai/toolset.py deleted file mode 100644 index 62efd7f7d..000000000 --- a/pydantic_ai_slim/pydantic_ai/toolset.py +++ /dev/null @@ -1,769 +0,0 @@ -from __future__ import annotations - -import asyncio -from abc import ABC, abstractmethod -from collections.abc import Awaitable, Iterable, Iterator, Sequence -from contextlib import AsyncExitStack, contextmanager -from dataclasses import dataclass, field, replace -from functools import partial -from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, overload - -from pydantic import ValidationError -from pydantic.json_schema import GenerateJsonSchema -from pydantic_core import SchemaValidator -from typing_extensions import Self - -from pydantic_ai.output import DeferredToolCalls - -from . import messages as _messages -from ._output import BaseOutputSchema, OutputValidator, ToolRetryError -from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError -from .tools import ( - DocstringFormat, - GenerateToolJsonSchema, - Tool, - ToolDefinition, - ToolFuncEither, - ToolParams, - ToolPrepareFunc, - ToolsPrepareFunc, -) - -if TYPE_CHECKING: - from pydantic_ai.models import Model - - -class AbstractToolset(ABC, Generic[AgentDepsT]): - """A toolset is a collection of tools that can be used by an agent. - - It is responsible for: - - Listing the tools it contains - - Validating the arguments of the tools - - Calling the tools - """ - - @property - def name(self) -> str: - return self.__class__.__name__.replace('Toolset', ' toolset') - - @property - def name_conflict_hint(self) -> str: - return 'Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.' - - async def __aenter__(self) -> Self: - return self - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: - return None - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - return RunToolset(self, ctx) - - @property - @abstractmethod - def tool_defs(self) -> list[ToolDefinition]: - raise NotImplementedError() - - @property - def tool_names(self) -> list[str]: - return [tool_def.name for tool_def in self.tool_defs] - - def get_tool_def(self, name: str) -> ToolDefinition | None: - return next((tool_def for tool_def in self.tool_defs if tool_def.name == name), None) - - @abstractmethod - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - raise NotImplementedError() - - def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False - ) -> dict[str, Any]: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - validator = self._get_tool_args_validator(ctx, name) - if isinstance(args, str): - return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial) - else: - return validator.validate_python(args or {}, allow_partial=pyd_allow_partial) - - @abstractmethod - def _max_retries_for_tool(self, name: str) -> int: - raise NotImplementedError() - - @abstractmethod - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - raise NotImplementedError() - - def set_mcp_sampling_model(self, model: Model) -> None: - pass - - -@dataclass(init=False) -class FunctionToolset(AbstractToolset[AgentDepsT]): - """A toolset that functions can be registered to as tools.""" - - max_retries: int = field(default=1) - tools: dict[str, Tool[Any]] = field(default_factory=dict) - - def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): - self.max_retries = max_retries - self.tools = {} - for tool in tools: - if isinstance(tool, Tool): - self.register_tool(tool) - else: - self.register_function(tool) - - @overload - def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... - - @overload - def tool( - self, - /, - *, - name: str | None = None, - retries: int | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, - docstring_format: DocstringFormat = 'auto', - require_parameter_descriptions: bool = False, - schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, - strict: bool | None = None, - ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ... - - def tool( - self, - func: ToolFuncEither[AgentDepsT, ToolParams] | None = None, - /, - *, - name: str | None = None, - retries: int | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, - docstring_format: DocstringFormat = 'auto', - require_parameter_descriptions: bool = False, - schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, - strict: bool | None = None, - ) -> Any: - """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. - - Can decorate a sync or async functions. - - The docstring is inspected to extract both the tool description and description of each parameter, - [learn more](../tools.md#function-tools-and-schema). - - We can't add overloads for every possible signature of tool, since the return type is a recursive union - so the signature of functions decorated with `@agent.tool` is obscured. - - Example: - ```python - from pydantic_ai import Agent, RunContext - - agent = Agent('test', deps_type=int) - - @agent.tool - def foobar(ctx: RunContext[int], x: int) -> int: - return ctx.deps + x - - @agent.tool(retries=2) - async def spam(ctx: RunContext[str], y: float) -> float: - return ctx.deps + y - - result = agent.run_sync('foobar', deps=1) - print(result.output) - #> {"foobar":1,"spam":1.0} - ``` - - Args: - func: The tool function to register. - name: The name of the tool, defaults to the function name. - retries: The number of retries to allow for this tool, defaults to the agent's default retries, - which defaults to 1. - prepare: custom method to prepare the tool definition for each step, return `None` to omit this - tool from a given step. This is useful if you want to customise a tool at call time, - or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. - docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. - Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. - require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. - schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. - strict: Whether to enforce JSON schema compliance (only affects OpenAI). - See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. - """ - if func is None: - - def tool_decorator( - func_: ToolFuncEither[AgentDepsT, ToolParams], - ) -> ToolFuncEither[AgentDepsT, ToolParams]: - # noinspection PyTypeChecker - self.register_function( - func_, - None, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func_ - - return tool_decorator - else: - # noinspection PyTypeChecker - self.register_function( - func, - None, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func - - def register_function( - self, - func: ToolFuncEither[AgentDepsT, ToolParams], - takes_ctx: bool | None = None, - name: str | None = None, - retries: int | None = None, - prepare: ToolPrepareFunc[AgentDepsT] | None = None, - docstring_format: DocstringFormat = 'auto', - require_parameter_descriptions: bool = False, - schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, - strict: bool | None = None, - ) -> None: - """Register a function as a tool.""" - tool = Tool[AgentDepsT]( - func, - takes_ctx=takes_ctx, - name=name, - max_retries=retries, - prepare=prepare, - docstring_format=docstring_format, - require_parameter_descriptions=require_parameter_descriptions, - schema_generator=schema_generator, - strict=strict, - ) - self.register_tool(tool) - - def register_tool(self, tool: Tool[AgentDepsT]) -> None: - if tool.name in self.tools: - raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}') - if tool.max_retries is None: - tool.max_retries = self.max_retries - self.tools[tool.name] = tool - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - self_for_run = RunToolset(self, ctx) - prepared_for_run = await IndividuallyPreparedToolset(self_for_run, self._prepare_tool_def).prepare_for_run(ctx) - return RunToolset(prepared_for_run, ctx, original=self) - - async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: - tool_name = tool_def.name - ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) - return await self.tools[tool_name].prepare_tool_def(ctx) - - @property - def tool_defs(self) -> list[ToolDefinition]: - return [tool.tool_def for tool in self.tools.values()] - - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self.tools[name].function_schema.validator - - def _max_retries_for_tool(self, name: str) -> int: - tool = self.tools[name] - return tool.max_retries if tool.max_retries is not None else self.max_retries - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self.tools[name].function_schema.call(tool_args, ctx) - - -@dataclass -class OutputToolset(AbstractToolset[AgentDepsT]): - """A toolset that contains output tools.""" - - output_schema: BaseOutputSchema[Any] - max_retries: int = field(default=1) # TODO: Test this works - output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list) - - @property - def tool_defs(self) -> list[ToolDefinition]: - return [tool.tool_def for tool in self.output_schema.tools.values()] - - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self.output_schema.tools[name].processor.validator - - def _max_retries_for_tool(self, name: str) -> int: - return self.max_retries - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - output = await self.output_schema.tools[name].processor.call(tool_args, ctx) - for validator in self.output_validators: - output = await validator.validate(output, None, ctx, wrap_validation_errors=False) - return output - - -class DeferredToolset(AbstractToolset[AgentDepsT]): - """A toolset that holds deferred tool.""" - - _tool_defs: list[ToolDefinition] - - def __init__(self, tool_defs: list[ToolDefinition]): - self._tool_defs = tool_defs - - @property - def tool_defs(self) -> list[ToolDefinition]: - return [replace(tool_def, kind='deferred') for tool_def in self._tool_defs] - - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - raise NotImplementedError('Deferred tools cannot be validated') - - def _max_retries_for_tool(self, name: str) -> int: - raise NotImplementedError('Deferred tools cannot be retried') - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - raise NotImplementedError('Deferred tools cannot be called') - - -@dataclass(init=False) -class CombinedToolset(AbstractToolset[AgentDepsT]): - """A toolset that combines multiple toolsets.""" - - toolsets: list[AbstractToolset[AgentDepsT]] - _exit_stack: AsyncExitStack | None - _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]] - - def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): - self._exit_stack = None - self.toolsets = list(toolsets) - - self._toolset_per_tool_name = {} - for toolset in self.toolsets: - for name in toolset.tool_names: - try: - existing_toolset = self._toolset_per_tool_name[name] - raise UserError( - f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset.name_conflict_hint}' - ) - except KeyError: - pass - self._toolset_per_tool_name[name] = toolset - - async def __aenter__(self) -> Self: - # TODO: running_count thing like in MCPServer? - self._exit_stack = AsyncExitStack() - for toolset in self.toolsets: - await self._exit_stack.enter_async_context(toolset) - return self - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: - if self._exit_stack is not None: - await self._exit_stack.aclose() - self._exit_stack = None - return None - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - toolsets_for_run = await asyncio.gather(*[toolset.prepare_for_run(ctx) for toolset in self.toolsets]) - combined_for_run = CombinedToolset(toolsets_for_run) - return RunToolset(combined_for_run, ctx) - - @property - def tool_defs(self) -> list[ToolDefinition]: - return [tool_def for toolset in self.toolsets for tool_def in toolset.tool_defs] - - @property - def tool_names(self) -> list[str]: - return list(self._toolset_per_tool_name.keys()) - - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self._toolset_for_tool_name(name)._get_tool_args_validator(ctx, name) - - def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False - ) -> dict[str, Any]: - return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial) - - def _max_retries_for_tool(self, name: str) -> int: - return self._toolset_for_tool_name(name)._max_retries_for_tool(name) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) - - def set_mcp_sampling_model(self, model: Model) -> None: - for toolset in self.toolsets: - toolset.set_mcp_sampling_model(model) - - def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: - try: - return self._toolset_per_tool_name[name] - except KeyError as e: - raise ValueError(f'Tool {name!r} not found in any toolset') from e - - -@dataclass -class WrapperToolset(AbstractToolset[AgentDepsT], ABC): - """A toolset that wraps another toolset and delegates to it.""" - - wrapped: AbstractToolset[AgentDepsT] - - @property - def name(self) -> str: - return self.wrapped.name - - @property - def name_conflict_hint(self) -> str: - return self.wrapped.name_conflict_hint - - async def __aenter__(self) -> Self: - await self.wrapped.__aenter__() - return self - - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: - return await self.wrapped.__aexit__(exc_type, exc_value, traceback) - - @abstractmethod - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - raise NotImplementedError() - - @property - def tool_defs(self) -> list[ToolDefinition]: - return self.wrapped.tool_defs - - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self.wrapped._get_tool_args_validator(ctx, name) - - def _max_retries_for_tool(self, name: str) -> int: - return self.wrapped._max_retries_for_tool(name) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) - - def set_mcp_sampling_model(self, model: Model) -> None: - self.wrapped.set_mcp_sampling_model(model) - - def __getattr__(self, item: str): - return getattr(self.wrapped, item) # pragma: no cover - - -@dataclass -class PrefixedToolset(WrapperToolset[AgentDepsT]): - """A toolset that prefixes the names of the tools it contains.""" - - prefix: str - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - wrapped_for_run = await self.wrapped.prepare_for_run(ctx) - prefixed_for_run = PrefixedToolset(wrapped_for_run, self.prefix) - return RunToolset(prefixed_for_run, ctx) - - @property - def tool_defs(self) -> list[ToolDefinition]: - return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs] - - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return super()._get_tool_args_validator(ctx, self._unprefixed_tool_name(name)) - - def _max_retries_for_tool(self, name: str) -> int: - return super()._max_retries_for_tool(self._unprefixed_tool_name(name)) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await super().call_tool(ctx, self._unprefixed_tool_name(name), tool_args, *args, **kwargs) - - def _prefixed_tool_name(self, tool_name: str) -> str: - return f'{self.prefix}_{tool_name}' - - def _unprefixed_tool_name(self, tool_name: str) -> str: - full_prefix = f'{self.prefix}_' - if not tool_name.startswith(full_prefix): - raise ValueError(f"Tool name '{tool_name}' does not start with prefix '{full_prefix}'") - return tool_name[len(full_prefix) :] - - -@dataclass -class PreparedToolset(WrapperToolset[AgentDepsT]): - """A toolset that prepares the tools it contains using a prepare function.""" - - prepare_func: ToolsPrepareFunc[AgentDepsT] - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - wrapped_for_run = await self.wrapped.prepare_for_run(ctx) - original_tool_defs = wrapped_for_run.tool_defs - prepared_tool_defs = await self.prepare_func(ctx, original_tool_defs) or [] - - original_tool_names = {tool_def.name for tool_def in original_tool_defs} - prepared_tool_names = {tool_def.name for tool_def in prepared_tool_defs} - if len(prepared_tool_names - original_tool_names) > 0: - raise UserError('Prepare function is not allowed to change tool names or add new tools.') - - prepared_for_run = PreparedToolset(wrapped_for_run, self.prepare_func) - return RunToolset(prepared_for_run, ctx, prepared_tool_defs) - - -@dataclass(init=False) -class MappedToolset(WrapperToolset[AgentDepsT]): - """A toolset that maps the names of the tools it contains.""" - - name_map: dict[str, str] - _tool_defs: list[ToolDefinition] - - def __init__( - self, - wrapped: AbstractToolset[AgentDepsT], - tool_defs: list[ToolDefinition], - name_map: dict[str, str], - ): - super().__init__(wrapped) - self._tool_defs = tool_defs - self.name_map = name_map - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - wrapped_for_run = await self.wrapped.prepare_for_run(ctx) - mapped_for_run = MappedToolset(wrapped_for_run, self._tool_defs, self.name_map) - return RunToolset(mapped_for_run, ctx) - - @property - def tool_defs(self) -> list[ToolDefinition]: - return self._tool_defs - - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return super()._get_tool_args_validator(ctx, self._map_name(name)) - - def _max_retries_for_tool(self, name: str) -> int: - return super()._max_retries_for_tool(self._map_name(name)) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await super().call_tool(ctx, self._map_name(name), tool_args, *args, **kwargs) - - def _map_name(self, name: str) -> str: - return self.name_map.get(name, name) - - -@dataclass -class IndividuallyPreparedToolset(WrapperToolset[AgentDepsT]): - """A toolset that prepares the tools it contains using a per-tool prepare function.""" - - prepare_func: ToolPrepareFunc[AgentDepsT] - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - wrapped_for_run = await self.wrapped.prepare_for_run(ctx) - - tool_defs: dict[str, ToolDefinition] = {} - name_map: dict[str, str] = {} - for original_tool_def in wrapped_for_run.tool_defs: - original_name = original_tool_def.name - tool_def = await self.prepare_func(ctx, original_tool_def) - if not tool_def: - continue - - new_name = tool_def.name - if new_name in tool_defs: - if new_name != original_name: - raise UserError(f"Renaming tool '{original_name}' to '{new_name}' conflicts with existing tool.") - else: - raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') - name_map[new_name] = original_name - - tool_defs[new_name] = tool_def - - mapped_for_run = await MappedToolset(wrapped_for_run, list(tool_defs.values()), name_map).prepare_for_run(ctx) - return RunToolset(mapped_for_run, ctx, original=self) - - -@dataclass(init=False) -class FilteredToolset(IndividuallyPreparedToolset[AgentDepsT]): - """A toolset that filters the tools it contains using a filter function.""" - - def __init__( - self, - toolset: AbstractToolset[AgentDepsT], - filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool], - ): - async def filter_tool_def(ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: - return tool_def if filter_func(ctx, tool_def) else None - - super().__init__(toolset, filter_tool_def) - - -class CallToolFunc(Protocol): - """A function protocol that represents a tool call.""" - - def __call__(self, name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any) -> Awaitable[Any]: ... - - -ToolProcessFunc = Callable[ - [ - RunContext[AgentDepsT], - CallToolFunc, - str, - dict[str, Any], - ], - Awaitable[Any], -] - - -@dataclass -class ProcessedToolset(WrapperToolset[AgentDepsT]): - """A toolset that lets the tool call arguments and return value be customized using a process function.""" - - process: ToolProcessFunc[AgentDepsT] - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - wrapped_for_run = await self.wrapped.prepare_for_run(ctx) - processed = ProcessedToolset(wrapped_for_run, self.process) - return RunToolset(processed, ctx) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self.process(ctx, partial(self.wrapped.call_tool, ctx), name, tool_args, *args, **kwargs) - - -@dataclass(init=False) -class RunToolset(WrapperToolset[AgentDepsT]): - """A toolset that is frozen for a specific run.""" - - ctx: RunContext[AgentDepsT] - _tool_defs: list[ToolDefinition] - _tool_names: list[str] - _retries: dict[str, int] - _original: AbstractToolset[AgentDepsT] - - def __init__( - self, - wrapped: AbstractToolset[AgentDepsT], - ctx: RunContext[AgentDepsT], - tool_defs: list[ToolDefinition] | None = None, - original: AbstractToolset[AgentDepsT] | None = None, - ): - self.wrapped = wrapped - self.ctx = ctx - self._tool_defs = wrapped.tool_defs if tool_defs is None else tool_defs - self._tool_names = [tool_def.name for tool_def in self._tool_defs] - self._retries = ctx.retries.copy() - self._original = original or wrapped - - @property - def name(self) -> str: - return self.wrapped.name - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - if ctx == self.ctx: - return self - else: - ctx = replace(ctx, retries=self._retries) - return await self._original.prepare_for_run(ctx) - - @property - def tool_defs(self) -> list[ToolDefinition]: - return self._tool_defs - - @property - def tool_names(self) -> list[str]: - return self._tool_names - - def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False - ) -> dict[str, Any]: - with self._with_retry(name, ctx) as ctx: - return super().validate_tool_args(ctx, name, args, allow_partial) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - with self._with_retry(name, ctx) as ctx: - try: - output = await super().call_tool(ctx, name, tool_args, *args, **kwargs) - except Exception as e: - raise e - else: - self._retries.pop(name, None) - return output - - def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None: - deferred_calls_and_defs = [ - (part, tool_def) - for part in parts - if isinstance(part, _messages.ToolCallPart) - and (tool_def := self.get_tool_def(part.tool_name)) - and tool_def.kind == 'deferred' - ] - if not deferred_calls_and_defs: - return None - - deferred_calls: list[_messages.ToolCallPart] = [] - deferred_tool_defs: dict[str, ToolDefinition] = {} - for part, tool_def in deferred_calls_and_defs: - deferred_calls.append(part) - deferred_tool_defs[part.tool_name] = tool_def - - return DeferredToolCalls(deferred_calls, deferred_tool_defs) - - @contextmanager - def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]: - try: - if name not in self.tool_names: - if self.tool_names: - msg = f'Available tools: {", ".join(self.tool_names)}' - else: - msg = 'No tools available.' - raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') - - ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) - yield ctx - except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: - try: - max_retries = self._max_retries_for_tool(name) - except Exception: - max_retries = 1 - current_retry = self._retries.get(name, 0) - - if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None: - e = e.__cause__ - - if current_retry == max_retries: - raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e - else: - if ctx.tool_call_id: - if isinstance(e, ValidationError): - m = _messages.RetryPromptPart( - tool_name=name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=ctx.tool_call_id, - ) - e = ToolRetryError(m) - elif isinstance(e, ModelRetry): - m = _messages.RetryPromptPart( - tool_name=name, - content=e.message, - tool_call_id=ctx.tool_call_id, - ) - e = ToolRetryError(m) - - self._retries[name] = current_retry + 1 - raise e diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py new file mode 100644 index 000000000..e914d0001 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from types import TracebackType +from typing import TYPE_CHECKING, Any, Generic, Literal + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition + +if TYPE_CHECKING: + from ..models import Model + from .run import RunToolset + + +class AbstractToolset(ABC, Generic[AgentDepsT]): + """A toolset is a collection of tools that can be used by an agent. + + It is responsible for: + - Listing the tools it contains + - Validating the arguments of the tools + - Calling the tools + """ + + @property + def name(self) -> str: + return self.__class__.__name__.replace('Toolset', ' toolset') + + @property + def tool_name_conflict_hint(self) -> str: + return 'Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.' + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return None + + @abstractmethod + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + raise NotImplementedError() + + @property + @abstractmethod + def tool_defs(self) -> list[ToolDefinition]: + raise NotImplementedError() + + @property + def tool_names(self) -> list[str]: + return [tool_def.name for tool_def in self.tool_defs] + + def get_tool_def(self, name: str) -> ToolDefinition | None: + return next((tool_def for tool_def in self.tool_defs if tool_def.name == name), None) + + @abstractmethod + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + raise NotImplementedError() + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False + ) -> dict[str, Any]: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + validator = self._get_tool_args_validator(ctx, name) + if isinstance(args, str): + return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial) + else: + return validator.validate_python(args or {}, allow_partial=pyd_allow_partial) + + @abstractmethod + def _max_retries_for_tool(self, name: str) -> int: + raise NotImplementedError() + + @abstractmethod + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + raise NotImplementedError() + + def set_mcp_sampling_model(self, model: Model) -> None: + pass diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py new file mode 100644 index 000000000..835718754 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from contextlib import AsyncExitStack +from dataclasses import dataclass +from types import TracebackType +from typing import TYPE_CHECKING, Any + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ToolDefinition +from . import AbstractToolset +from .run import RunToolset + +if TYPE_CHECKING: + from ..models import Model + + +@dataclass(init=False) +class CombinedToolset(AbstractToolset[AgentDepsT]): + """A toolset that combines multiple toolsets.""" + + toolsets: list[AbstractToolset[AgentDepsT]] + _exit_stack: AsyncExitStack | None + _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]] + + def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): + self._exit_stack = None + self.toolsets = list(toolsets) + + self._toolset_per_tool_name = {} + for toolset in self.toolsets: + for name in toolset.tool_names: + try: + existing_toolset = self._toolset_per_tool_name[name] + raise UserError( + f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + ) + except KeyError: + pass + self._toolset_per_tool_name[name] = toolset + + async def __aenter__(self) -> Self: + # TODO: running_count thing like in MCPServer? + self._exit_stack = AsyncExitStack() + for toolset in self.toolsets: + await self._exit_stack.enter_async_context(toolset) + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + if self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + return None + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + toolsets_for_run = await asyncio.gather(*[toolset.prepare_for_run(ctx) for toolset in self.toolsets]) + combined_for_run = CombinedToolset(toolsets_for_run) + return RunToolset(combined_for_run, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [tool_def for toolset in self.toolsets for tool_def in toolset.tool_defs] + + @property + def tool_names(self) -> list[str]: + return list(self._toolset_per_tool_name.keys()) + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self._toolset_for_tool_name(name)._get_tool_args_validator(ctx, name) + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False + ) -> dict[str, Any]: + return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial) + + def _max_retries_for_tool(self, name: str) -> int: + return self._toolset_for_tool_name(name)._max_retries_for_tool(name) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) + + def set_mcp_sampling_model(self, model: Model) -> None: + for toolset in self.toolsets: + toolset.set_mcp_sampling_model(model) + + def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: + try: + return self._toolset_per_tool_name[name] + except KeyError as e: + raise ValueError(f'Tool {name!r} not found in any toolset') from e diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py new file mode 100644 index 000000000..1bc69628b --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import replace +from typing import Any + +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from . import AbstractToolset +from .run import RunToolset + + +class DeferredToolset(AbstractToolset[AgentDepsT]): + """A toolset that holds deferred tool.""" + + _tool_defs: list[ToolDefinition] + + def __init__(self, tool_defs: list[ToolDefinition]): + self._tool_defs = tool_defs + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + return RunToolset(self, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [replace(tool_def, kind='deferred') for tool_def in self._tool_defs] + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + raise NotImplementedError('Deferred tools cannot be validated') + + def _max_retries_for_tool(self, name: str) -> int: + raise NotImplementedError('Deferred tools cannot be retried') + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + raise NotImplementedError('Deferred tools cannot be called') diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py new file mode 100644 index 000000000..6a6789620 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from . import AbstractToolset +from .individually_prepared import IndividuallyPreparedToolset + + +@dataclass(init=False) +class FilteredToolset(IndividuallyPreparedToolset[AgentDepsT]): + """A toolset that filters the tools it contains using a filter function.""" + + def __init__( + self, + toolset: AbstractToolset[AgentDepsT], + filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool], + ): + async def filter_tool_def(ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: + return tool_def if filter_func(ctx, tool_def) else None + + super().__init__(toolset, filter_tool_def) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py new file mode 100644 index 000000000..0f5261891 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field, replace +from typing import Any, Callable, overload + +from pydantic.json_schema import GenerateJsonSchema +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ( + DocstringFormat, + GenerateToolJsonSchema, + Tool, + ToolDefinition, + ToolFuncEither, + ToolParams, + ToolPrepareFunc, +) +from . import AbstractToolset +from .individually_prepared import IndividuallyPreparedToolset +from .run import RunToolset + + +@dataclass(init=False) +class FunctionToolset(AbstractToolset[AgentDepsT]): + """A toolset that functions can be registered to as tools.""" + + max_retries: int = field(default=1) + tools: dict[str, Tool[Any]] = field(default_factory=dict) + + def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + self.max_retries = max_retries + self.tools = {} + for tool in tools: + if isinstance(tool, Tool): + self.register_tool(tool) + else: + self.register_function(tool) + + @overload + def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... + + @overload + def tool( + self, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ... + + def tool( + self, + func: ToolFuncEither[AgentDepsT, ToolParams] | None = None, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Any: + """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. + + Can decorate a sync or async functions. + + The docstring is inspected to extract both the tool description and description of each parameter, + [learn more](../tools.md#function-tools-and-schema). + + We can't add overloads for every possible signature of tool, since the return type is a recursive union + so the signature of functions decorated with `@agent.tool` is obscured. + + Example: + ```python + from pydantic_ai import Agent, RunContext + + agent = Agent('test', deps_type=int) + + @agent.tool + def foobar(ctx: RunContext[int], x: int) -> int: + return ctx.deps + x + + @agent.tool(retries=2) + async def spam(ctx: RunContext[str], y: float) -> float: + return ctx.deps + y + + result = agent.run_sync('foobar', deps=1) + print(result.output) + #> {"foobar":1,"spam":1.0} + ``` + + Args: + func: The tool function to register. + name: The name of the tool, defaults to the function name. + retries: The number of retries to allow for this tool, defaults to the agent's default retries, + which defaults to 1. + prepare: custom method to prepare the tool definition for each step, return `None` to omit this + tool from a given step. This is useful if you want to customise a tool at call time, + or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. + docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. + Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. + require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. + schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. + strict: Whether to enforce JSON schema compliance (only affects OpenAI). + See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + """ + if func is None: + + def tool_decorator( + func_: ToolFuncEither[AgentDepsT, ToolParams], + ) -> ToolFuncEither[AgentDepsT, ToolParams]: + # noinspection PyTypeChecker + self.register_function( + func_, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func_ + + return tool_decorator + else: + # noinspection PyTypeChecker + self.register_function( + func, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func + + def register_function( + self, + func: ToolFuncEither[AgentDepsT, ToolParams], + takes_ctx: bool | None = None, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> None: + """Register a function as a tool.""" + tool = Tool[AgentDepsT]( + func, + takes_ctx=takes_ctx, + name=name, + max_retries=retries, + prepare=prepare, + docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions, + schema_generator=schema_generator, + strict=strict, + ) + self.register_tool(tool) + + def register_tool(self, tool: Tool[AgentDepsT]) -> None: + if tool.name in self.tools: + raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}') + if tool.max_retries is None: + tool.max_retries = self.max_retries + self.tools[tool.name] = tool + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + self_for_run = RunToolset(self, ctx) + prepared_for_run = await IndividuallyPreparedToolset(self_for_run, self._prepare_tool_def).prepare_for_run(ctx) + return RunToolset(prepared_for_run, ctx, original=self) + + async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: + tool_name = tool_def.name + ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) + return await self.tools[tool_name].prepare_tool_def(ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [tool.tool_def for tool in self.tools.values()] + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.tools[name].function_schema.validator + + def _max_retries_for_tool(self, name: str) -> int: + tool = self.tools[name] + return tool.max_retries if tool.max_retries is not None else self.max_retries + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.tools[name].function_schema.call(tool_args, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py new file mode 100644 index 000000000..4ff6b387d --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ( + ToolDefinition, + ToolPrepareFunc, +) +from .mapped import MappedToolset +from .run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass +class IndividuallyPreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a per-tool prepare function.""" + + prepare_func: ToolPrepareFunc[AgentDepsT] + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + + tool_defs: dict[str, ToolDefinition] = {} + name_map: dict[str, str] = {} + for original_tool_def in wrapped_for_run.tool_defs: + original_name = original_tool_def.name + tool_def = await self.prepare_func(ctx, original_tool_def) + if not tool_def: + continue + + new_name = tool_def.name + if new_name in tool_defs: + if new_name != original_name: + raise UserError(f"Renaming tool '{original_name}' to '{new_name}' conflicts with existing tool.") + else: + raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') + name_map[new_name] = original_name + + tool_defs[new_name] = tool_def + + mapped_for_run = await MappedToolset(wrapped_for_run, list(tool_defs.values()), name_map).prepare_for_run(ctx) + return RunToolset(mapped_for_run, ctx, original=self) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/mapped.py b/pydantic_ai_slim/pydantic_ai/toolsets/mapped.py new file mode 100644 index 000000000..bf6fb9e95 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/mapped.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..tools import ( + ToolDefinition, +) +from . import AbstractToolset +from .run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass(init=False) +class MappedToolset(WrapperToolset[AgentDepsT]): + """A toolset that maps the names of the tools it contains.""" + + name_map: dict[str, str] + _tool_defs: list[ToolDefinition] + + def __init__( + self, + wrapped: AbstractToolset[AgentDepsT], + tool_defs: list[ToolDefinition], + name_map: dict[str, str], + ): + super().__init__(wrapped) + self._tool_defs = tool_defs + self.name_map = name_map + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + mapped_for_run = MappedToolset(wrapped_for_run, self._tool_defs, self.name_map) + return RunToolset(mapped_for_run, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super()._get_tool_args_validator(ctx, self._map_name(name)) + + def _max_retries_for_tool(self, name: str) -> int: + return super()._max_retries_for_tool(self._map_name(name)) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await super().call_tool(ctx, self._map_name(name), tool_args, *args, **kwargs) + + def _map_name(self, name: str) -> str: + return self.name_map.get(name, name) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py new file mode 100644 index 000000000..857b923d1 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from .run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass +class PrefixedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prefixes the names of the tools it contains.""" + + prefix: str + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + prefixed_for_run = PrefixedToolset(wrapped_for_run, self.prefix) + return RunToolset(prefixed_for_run, ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs] + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return super()._get_tool_args_validator(ctx, self._unprefixed_tool_name(name)) + + def _max_retries_for_tool(self, name: str) -> int: + return super()._max_retries_for_tool(self._unprefixed_tool_name(name)) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await super().call_tool(ctx, self._unprefixed_tool_name(name), tool_args, *args, **kwargs) + + def _prefixed_tool_name(self, tool_name: str) -> str: + return f'{self.prefix}_{tool_name}' + + def _unprefixed_tool_name(self, tool_name: str) -> str: + full_prefix = f'{self.prefix}_' + if not tool_name.startswith(full_prefix): + raise ValueError(f"Tool name '{tool_name}' does not start with prefix '{full_prefix}'") + return tool_name[len(full_prefix) :] diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py new file mode 100644 index 000000000..2ba2b288a --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ToolsPrepareFunc +from .run import RunToolset +from .wrapper import WrapperToolset + + +@dataclass +class PreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a prepare function.""" + + prepare_func: ToolsPrepareFunc[AgentDepsT] + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + original_tool_defs = wrapped_for_run.tool_defs + prepared_tool_defs = await self.prepare_func(ctx, original_tool_defs) or [] + + original_tool_names = {tool_def.name for tool_def in original_tool_defs} + prepared_tool_names = {tool_def.name for tool_def in prepared_tool_defs} + if len(prepared_tool_names - original_tool_names) > 0: + raise UserError('Prepare function is not allowed to change tool names or add new tools.') + + prepared_for_run = PreparedToolset(wrapped_for_run, self.prepare_func) + return RunToolset(prepared_for_run, ctx, prepared_tool_defs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/processed.py b/pydantic_ai_slim/pydantic_ai/toolsets/processed.py new file mode 100644 index 000000000..084296377 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/processed.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +from collections.abc import Awaitable +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Protocol + +from .._run_context import AgentDepsT, RunContext +from .run import RunToolset +from .wrapper import WrapperToolset + + +class CallToolFunc(Protocol): + """A function protocol that represents a tool call.""" + + def __call__(self, name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any) -> Awaitable[Any]: ... + + +ToolProcessFunc = Callable[ + [ + RunContext[AgentDepsT], + CallToolFunc, + str, + dict[str, Any], + ], + Awaitable[Any], +] + + +@dataclass +class ProcessedToolset(WrapperToolset[AgentDepsT]): + """A toolset that lets the tool call arguments and return value be customized using a process function.""" + + process: ToolProcessFunc[AgentDepsT] + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + wrapped_for_run = await self.wrapped.prepare_for_run(ctx) + processed = ProcessedToolset(wrapped_for_run, self.process) + return RunToolset(processed, ctx) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.process(ctx, partial(self.wrapped.call_tool, ctx), name, tool_args, *args, **kwargs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/run.py b/pydantic_ai_slim/pydantic_ai/toolsets/run.py new file mode 100644 index 000000000..01a991f12 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/run.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass, replace +from typing import Any + +from pydantic import ValidationError + +from pydantic_ai.output import DeferredToolCalls + +from .. import messages as _messages +from .._run_context import AgentDepsT, RunContext +from ..exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior +from ..tools import ToolDefinition +from . import AbstractToolset +from .wrapper import WrapperToolset + + +@dataclass(init=False) +class RunToolset(WrapperToolset[AgentDepsT]): + """A toolset that is frozen for a specific run.""" + + ctx: RunContext[AgentDepsT] + _tool_defs: list[ToolDefinition] + _tool_names: list[str] + _retries: dict[str, int] + _original: AbstractToolset[AgentDepsT] + + def __init__( + self, + wrapped: AbstractToolset[AgentDepsT], + ctx: RunContext[AgentDepsT], + tool_defs: list[ToolDefinition] | None = None, + original: AbstractToolset[AgentDepsT] | None = None, + ): + self.wrapped = wrapped + self.ctx = ctx + self._tool_defs = wrapped.tool_defs if tool_defs is None else tool_defs + self._tool_names = [tool_def.name for tool_def in self._tool_defs] + self._retries = ctx.retries.copy() + self._original = original or wrapped + + @property + def name(self) -> str: + return self.wrapped.name + + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + if ctx == self.ctx: + return self + else: + ctx = replace(ctx, retries=self._retries) + return await self._original.prepare_for_run(ctx) + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self._tool_defs + + @property + def tool_names(self) -> list[str]: + return self._tool_names + + def validate_tool_args( + self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False + ) -> dict[str, Any]: + with self._with_retry(name, ctx) as ctx: + return super().validate_tool_args(ctx, name, args, allow_partial) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + with self._with_retry(name, ctx) as ctx: + try: + output = await super().call_tool(ctx, name, tool_args, *args, **kwargs) + except Exception as e: + raise e + else: + self._retries.pop(name, None) + return output + + def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None: + deferred_calls_and_defs = [ + (part, tool_def) + for part in parts + if isinstance(part, _messages.ToolCallPart) + and (tool_def := self.get_tool_def(part.tool_name)) + and tool_def.kind == 'deferred' + ] + if not deferred_calls_and_defs: + return None + + deferred_calls: list[_messages.ToolCallPart] = [] + deferred_tool_defs: dict[str, ToolDefinition] = {} + for part, tool_def in deferred_calls_and_defs: + deferred_calls.append(part) + deferred_tool_defs[part.tool_name] = tool_def + + return DeferredToolCalls(deferred_calls, deferred_tool_defs) + + @contextmanager + def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]: + try: + if name not in self.tool_names: + if self.tool_names: + msg = f'Available tools: {", ".join(self.tool_names)}' + else: + msg = 'No tools available.' + raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') + + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) + yield ctx + except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: + try: + max_retries = self._max_retries_for_tool(name) + except Exception: + max_retries = 1 + current_retry = self._retries.get(name, 0) + + if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None: + e = e.__cause__ + + if current_retry == max_retries: + raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e + else: + if ctx.tool_call_id: + if isinstance(e, ValidationError): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.errors(include_url=False, include_context=False), + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + elif isinstance(e, ModelRetry): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.message, + tool_call_id=ctx.tool_call_id, + ) + e = ToolRetryError(m) + + self._retries[name] = current_retry + 1 + raise e diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py new file mode 100644 index 000000000..89d9de656 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from types import TracebackType +from typing import TYPE_CHECKING, Any + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from . import AbstractToolset + +if TYPE_CHECKING: + from ..models import Model + from .run import RunToolset + + +@dataclass +class WrapperToolset(AbstractToolset[AgentDepsT], ABC): + """A toolset that wraps another toolset and delegates to it.""" + + wrapped: AbstractToolset[AgentDepsT] + + @property + def name(self) -> str: + return self.wrapped.name + + @property + def tool_name_conflict_hint(self) -> str: + return self.wrapped.tool_name_conflict_hint + + async def __aenter__(self) -> Self: + await self.wrapped.__aenter__() + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None + ) -> bool | None: + return await self.wrapped.__aexit__(exc_type, exc_value, traceback) + + @abstractmethod + async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: + raise NotImplementedError() + + @property + def tool_defs(self) -> list[ToolDefinition]: + return self.wrapped.tool_defs + + def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: + return self.wrapped._get_tool_args_validator(ctx, name) + + def _max_retries_for_tool(self, name: str) -> int: + return self.wrapped._max_retries_for_tool(name) + + async def call_tool( + self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any + ) -> Any: + return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) + + def set_mcp_sampling_model(self, model: Model) -> None: + self.wrapped.set_mcp_sampling_model(model) + + def __getattr__(self, item: str): + return getattr(self.wrapped, item) # pragma: no cover diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 099152498..02aafd259 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -158,7 +158,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel: call_count += 1 raise ModelRetry('Fail') - with pytest.raises(UnexpectedModelBehavior, match=re.escape('Exceeded maximum retries (2) for result validation')): + with pytest.raises(UnexpectedModelBehavior, match=re.escape('Exceeded maximum retries (2) for output validation')): agent.run_sync('Hello', model=TestModel()) assert call_count == 3 @@ -201,7 +201,7 @@ class ResultModel(BaseModel): agent = Agent('test', output_type=ResultModel, retries=2) - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for output validation'): agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1})) diff --git a/tests/test_agent.py b/tests/test_agent.py index 51ccbc613..0473477ea 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -46,7 +46,7 @@ from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition -from pydantic_ai.toolset import FunctionToolset +from pydantic_ai.toolsets.function import FunctionToolset from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -955,7 +955,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], ) def test_output_type_multiple_text_output(output_type: OutputSpec[str]): - with pytest.raises(UserError, match='Only one text output is allowed.'): + with pytest.raises(UserError, match='Only one `str` or `TextOutput` is allowed.'): Agent('test', output_type=output_type) @@ -1945,7 +1945,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent(FunctionModel(empty)) with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): agent.run_sync('Hello') assert messages == snapshot( [ diff --git a/tests/test_examples.py b/tests/test_examples.py index 250f29553..cf9f98802 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -39,7 +39,8 @@ from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition -from pydantic_ai.toolset import AbstractToolset +from pydantic_ai.toolsets import AbstractToolset +from pydantic_ai.toolsets.run import RunToolset from .conftest import ClientWithHandler, TestEnv, try_import @@ -269,6 +270,9 @@ async def __aenter__(self) -> MockMCPServer: async def __aexit__(self, *args: Any) -> None: pass + async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: + return RunToolset(self, ctx) + @property def tool_defs(self) -> list[ToolDefinition]: return [] diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 41230f0aa..a7ed001df 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -28,7 +28,7 @@ from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext -from pydantic_ai.toolset import CallToolFunc +from pydantic_ai.toolsets.processed import CallToolFunc from pydantic_ai.usage import Usage from .conftest import IsDatetime, IsNow, IsStr, try_import @@ -63,7 +63,7 @@ def model(openai_api_key: str) -> Model: @pytest.fixture def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: - return Agent(model, mcp_servers=[mcp_server]) + return Agent(model, toolsets=[mcp_server]) @pytest.fixture @@ -122,7 +122,7 @@ async def process_tool_call( server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) async with server: - agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), mcp_servers=[server]) + agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), toolsets=[server]) result = await agent.run('Echo with deps set to 42', deps=42) assert result.output == snapshot('{"echo_deps":{"echo":"This is an echo message","deps":42}}') assert called, 'process_tool_call should have been called' @@ -243,7 +243,7 @@ async def test_agent_with_prefix_tool_name(openai_api_key: str): model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) agent = Agent( model, - mcp_servers=[server], + toolsets=[server], ) @agent.tool_plain @@ -260,7 +260,7 @@ def get_none() -> None: # pragma: no cover async def test_agent_with_server_not_running(openai_api_key: str): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - agent = Agent(model, mcp_servers=[server]) + agent = Agent(model, toolsets=[server]) with pytest.raises(UserError, match='MCP server is not running'): await agent.run('What is 0 degrees Celsius in Fahrenheit?') diff --git a/tests/test_streaming.py b/tests/test_streaming.py index e70d642b3..ba81fd9fd 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -276,7 +276,7 @@ async def text_stream(_messages: list[ModelMessage], _: AgentInfo) -> AsyncItera agent = Agent(FunctionModel(stream_function=text_stream), output_type=tuple[str, str]) - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): async with agent.run_stream(''): pass @@ -411,7 +411,7 @@ async def ret_a(x: str) -> str: # pragma: no cover return x with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for output validation'): async with agent.run_stream('hello'): pass diff --git a/tests/test_tools.py b/tests/test_tools.py index 8ce580bf3..fe582f717 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -18,7 +18,7 @@ from pydantic_ai.models.test import TestModel from pydantic_ai.output import DeferredToolCalls, ToolOutput from pydantic_ai.tools import ToolDefinition -from pydantic_ai.toolset import DeferredToolset +from pydantic_ai.toolsets.deferred import DeferredToolset from .conftest import IsStr @@ -1233,7 +1233,7 @@ class MyModel(BaseModel): def test_output_type_deferred_tool_calls_by_itself(): - with pytest.raises(UserError, match='At least one output type must be provided other than DeferredToolCalls.'): + with pytest.raises(UserError, match='At least one output type must be provided other than `DeferredToolCalls`.'): Agent(TestModel(), output_type=DeferredToolCalls) diff --git a/tests/test_toolset.py b/tests/test_toolset.py index c58dce68e..aea2c13c0 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -9,7 +9,7 @@ from pydantic_ai._run_context import RunContext from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition -from pydantic_ai.toolset import FunctionToolset +from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.usage import Usage pytestmark = pytest.mark.anyio From f660cc1f04f383f1e0fd4e9c6adc7974e4ce62f4 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 2 Jul 2025 00:00:00 +0000 Subject: [PATCH 71/90] Some more tweaks --- pydantic_ai_slim/pydantic_ai/mcp.py | 3 +-- pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ab07ed33e..3743c2cc0 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -163,8 +163,7 @@ async def call_tool( return content[0] if len(content) == 1 else content async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: - frozen_self = RunToolset(self, ctx, await self.list_tool_defs()) - frozen_toolset = frozen_self + frozen_toolset = RunToolset(self, ctx, await self.list_tool_defs()) if self.process_tool_call: frozen_toolset = await ProcessedToolset(frozen_toolset, self.process_tool_call).prepare_for_run(ctx) if self.tool_prefix: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py index 4ff6b387d..d59a54196 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py @@ -33,7 +33,7 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent new_name = tool_def.name if new_name in tool_defs: if new_name != original_name: - raise UserError(f"Renaming tool '{original_name}' to '{new_name}' conflicts with existing tool.") + raise UserError(f'Renaming tool {original_name!r} to {new_name!r} conflicts with existing tool.') else: raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') name_map[new_name] = original_name From 5ca305ec047d2b52d24af863b0bb008df2f16456 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 2 Jul 2025 00:17:29 +0000 Subject: [PATCH 72/90] Fix docs example --- docs/mcp/client.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 17de670d6..a6efd847a 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -183,7 +183,7 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio, ToolResult from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext -from pydantic_ai.toolset import CallToolFunc +from pydantic_ai.toolsets.processed import CallToolFunc async def process_tool_call( From c5ef5f6c5671f5ac6bd7aebdbee0d257e036cc06 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 2 Jul 2025 22:43:54 +0000 Subject: [PATCH 73/90] Address some feedback --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 6 +- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- pydantic_ai_slim/pydantic_ai/_run_context.py | 5 +- pydantic_ai_slim/pydantic_ai/agent.py | 12 +- pydantic_ai_slim/pydantic_ai/mcp.py | 47 +-- pydantic_ai_slim/pydantic_ai/result.py | 2 +- pydantic_ai_slim/pydantic_ai/tools.py | 4 +- .../pydantic_ai/toolsets/__init__.py | 4 +- ..._prepared.py => _individually_prepared.py} | 4 +- .../toolsets/{mapped.py => _mapped.py} | 4 +- .../pydantic_ai/toolsets/{run.py => _run.py} | 11 +- .../pydantic_ai/toolsets/combined.py | 22 +- .../pydantic_ai/toolsets/deferred.py | 2 +- .../pydantic_ai/toolsets/filtered.py | 2 +- .../pydantic_ai/toolsets/function.py | 8 +- .../pydantic_ai/toolsets/prefixed.py | 2 +- .../pydantic_ai/toolsets/prepared.py | 2 +- .../pydantic_ai/toolsets/processed.py | 4 +- .../pydantic_ai/toolsets/wrapper.py | 6 +- .../test_agent_with_server_not_running.yaml | 391 ++++++++++++++++++ tests/test_agent.py | 4 +- tests/test_examples.py | 4 +- tests/test_mcp.py | 10 +- tests/test_streaming.py | 6 +- tests/test_tools.py | 2 +- 25 files changed, 480 insertions(+), 86 deletions(-) rename pydantic_ai_slim/pydantic_ai/toolsets/{individually_prepared.py => _individually_prepared.py} (96%) rename pydantic_ai_slim/pydantic_ai/toolsets/{mapped.py => _mapped.py} (90%) rename pydantic_ai_slim/pydantic_ai/toolsets/{run.py => _run.py} (93%) create mode 100644 tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 251fc55ac..e3bcee4aa 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -17,7 +17,7 @@ from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore from pydantic_ai._utils import is_async_callable, run_in_executor from pydantic_ai.toolsets import AbstractToolset -from pydantic_ai.toolsets.run import RunToolset +from pydantic_ai.toolsets._run import RunToolset from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT @@ -505,7 +505,7 @@ async def _handle_tool_calls( elif deferred_tool_calls := ctx.deps.toolset.get_deferred_tool_calls(tool_calls): if not ctx.deps.output_schema.allows_deferred_tool_calls: raise exceptions.UserError( - 'There are deferred tool calls but DeferredToolCalls is not among output types.' + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' ) final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None) self._next_node = self._handle_final_result(ctx, final_result, output_parts) @@ -586,7 +586,7 @@ async def process_function_tools( # noqa: C901 Also add stub return parts for any other tools that need it. - Because async iterators can't have return values, we use `parts` as an output argument. + Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments. """ run_context = build_run_context(ctx) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 6a22f76e8..ab505e3a4 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -29,7 +29,7 @@ ) from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition from .toolsets import AbstractToolset -from .toolsets.run import RunToolset +from .toolsets._run import RunToolset if TYPE_CHECKING: from .profiles import ModelProfile diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 5f705cd4f..46f7e664b 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import dataclasses +from collections import defaultdict from collections.abc import Sequence from dataclasses import field from typing import TYPE_CHECKING, Generic @@ -31,8 +32,8 @@ class RunContext(Generic[AgentDepsT]): """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) """Messages exchanged in the conversation so far.""" - retries: dict[str, int] = field(default_factory=dict) - """Number of retries for each tool.""" + retries: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int)) + """Number of retries for each tool so far.""" tool_call_id: str | None = None """The ID of the tool call.""" tool_name: str | None = None diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 07891d9eb..6c2586012 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1783,18 +1783,18 @@ def is_end_node( @asynccontextmanager async def run_toolsets( - self, model: models.Model | models.KnownModelName | str | None = None + self, sampling_model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: - """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. + """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] among toolsets so they can be used by the agent. Returns: a context manager to start and shutdown the servers. """ try: - sampling_model: models.Model | None = self._get_model(model) + model: models.Model | None = self._get_model(sampling_model) except exceptions.UserError: # pragma: no cover - sampling_model = None - if sampling_model is not None: # pragma: no branch - self._toolset.set_mcp_sampling_model(sampling_model) + model = None + if model is not None: # pragma: no branch + self._toolset._set_mcp_sampling_model(model) # type: ignore[reportPrivateUsage] async with self._toolset: yield diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 3743c2cc0..37bf71055 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -19,11 +19,10 @@ from pydantic_ai._run_context import RunContext from pydantic_ai.tools import ToolDefinition -from .exceptions import UserError from .toolsets import AbstractToolset +from .toolsets._run import RunToolset from .toolsets.prefixed import PrefixedToolset from .toolsets.processed import ProcessedToolset, ToolProcessFunc -from .toolsets.run import RunToolset try: from mcp import types as mcp_types @@ -104,9 +103,8 @@ async def list_tools(self) -> list[mcp_types.Tool]: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ - if not self.is_running: # pragma: no cover - raise UserError(f'MCP server is not running: {self}') - result = await self._client.list_tools() + async with self: + result = await self._client.list_tools() return result.tools async def call_tool( @@ -134,25 +132,24 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ - if not self.is_running: # pragma: no cover - raise UserError(f'MCP server is not running: {self}') - try: - # meta param is not provided by session yet, so build and can send_request directly. - result = await self._client.send_request( - mcp_types.ClientRequest( - mcp_types.CallToolRequest( - method='tools/call', - params=mcp_types.CallToolRequestParams( - name=name, - arguments=tool_args, - _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, - ), - ) - ), - mcp_types.CallToolResult, - ) - except McpError as e: - raise exceptions.ModelRetry(e.error.message) + async with self: + try: + # meta param is not provided by session yet, so build and can send_request directly. + result = await self._client.send_request( + mcp_types.ClientRequest( + mcp_types.CallToolRequest( + method='tools/call', + params=mcp_types.CallToolRequestParams( + name=name, + arguments=tool_args, + _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + ), + ) + ), + mcp_types.CallToolResult, + ) + except McpError as e: + raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] @@ -196,7 +193,7 @@ def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_ def _max_retries_for_tool(self, name: str) -> int: return self.max_retries - def set_mcp_sampling_model(self, model: models.Model) -> None: + def _set_mcp_sampling_model(self, model: models.Model) -> None: self.sampling_model = model async def __aenter__(self) -> Self: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 2d1274318..5f732aa86 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -10,7 +10,7 @@ from pydantic import ValidationError from typing_extensions import TypeVar, deprecated, overload -from pydantic_ai.toolsets.run import RunToolset +from pydantic_ai.toolsets._run import RunToolset from . import _utils, exceptions, messages as _messages, models from ._output import ( diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 782f11b09..e79d932a8 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -366,8 +366,8 @@ class ToolDefinition: """The kind of tool: - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model - `'output'`: a tool that passes through an output value that ends the run - - `'deferred'`: a tool that cannot be executed by Pydantic AI and needs to get a result from the outside. - When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s for each deferred call. + - `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). + When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call. """ __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py index e914d0001..951574b14 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from ..models import Model - from .run import RunToolset + from ._run import RunToolset class AbstractToolset(ABC, Generic[AgentDepsT]): @@ -80,5 +80,5 @@ async def call_tool( ) -> Any: raise NotImplementedError() - def set_mcp_sampling_model(self, model: Model) -> None: + def _set_mcp_sampling_model(self, model: Model) -> None: pass diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py similarity index 96% rename from pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py rename to pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py index d59a54196..88ccbd728 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/individually_prepared.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py @@ -8,8 +8,8 @@ ToolDefinition, ToolPrepareFunc, ) -from .mapped import MappedToolset -from .run import RunToolset +from ._mapped import MappedToolset +from ._run import RunToolset from .wrapper import WrapperToolset diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/mapped.py b/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py similarity index 90% rename from pydantic_ai_slim/pydantic_ai/toolsets/mapped.py rename to pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py index bf6fb9e95..47a7a9e72 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/mapped.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py @@ -10,13 +10,13 @@ ToolDefinition, ) from . import AbstractToolset -from .run import RunToolset +from ._run import RunToolset from .wrapper import WrapperToolset @dataclass(init=False) class MappedToolset(WrapperToolset[AgentDepsT]): - """A toolset that maps the names of the tools it contains.""" + """A toolset that maps renamed tool names to original tool names. Used by `IndividuallyPreparedToolset` as the prepare function may rename a tool.""" name_map: dict[str, str] _tool_defs: list[ToolDefinition] diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/run.py b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py similarity index 93% rename from pydantic_ai_slim/pydantic_ai/toolsets/run.py rename to pydantic_ai_slim/pydantic_ai/toolsets/_run.py index 01a991f12..f04068666 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/run.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict from collections.abc import Iterable, Iterator from contextlib import contextmanager from dataclasses import dataclass, replace @@ -19,12 +20,12 @@ @dataclass(init=False) class RunToolset(WrapperToolset[AgentDepsT]): - """A toolset that is frozen for a specific run.""" + """A toolset that caches the wrapped toolset's tool definitions for a specific run step and handles retries.""" ctx: RunContext[AgentDepsT] _tool_defs: list[ToolDefinition] _tool_names: list[str] - _retries: dict[str, int] + _retries: defaultdict[str, int] _original: AbstractToolset[AgentDepsT] def __init__( @@ -102,19 +103,19 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon try: if name not in self.tool_names: if self.tool_names: - msg = f'Available tools: {", ".join(self.tool_names)}' + msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tool_names)}' else: msg = 'No tools available.' raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') - ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) + ctx = replace(ctx, tool_name=name, retry=self._retries[name], retries={}) yield ctx except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: try: max_retries = self._max_retries_for_tool(name) except Exception: max_retries = 1 - current_retry = self._retries.get(name, 0) + current_retry = self._retries[name] if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None: e = e.__cause__ diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 835718754..bd7784953 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -14,7 +14,7 @@ from ..exceptions import UserError from ..tools import ToolDefinition from . import AbstractToolset -from .run import RunToolset +from ._run import RunToolset if TYPE_CHECKING: from ..models import Model @@ -25,11 +25,13 @@ class CombinedToolset(AbstractToolset[AgentDepsT]): """A toolset that combines multiple toolsets.""" toolsets: list[AbstractToolset[AgentDepsT]] - _exit_stack: AsyncExitStack | None _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]] + _exit_stack: AsyncExitStack | None + _running_count: int def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): self._exit_stack = None + self._running_count = 0 self.toolsets = list(toolsets) self._toolset_per_tool_name = {} @@ -45,16 +47,18 @@ def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): self._toolset_per_tool_name[name] = toolset async def __aenter__(self) -> Self: - # TODO: running_count thing like in MCPServer? - self._exit_stack = AsyncExitStack() - for toolset in self.toolsets: - await self._exit_stack.enter_async_context(toolset) + if self._running_count == 0: + self._exit_stack = AsyncExitStack() + for toolset in self.toolsets: + await self._exit_stack.enter_async_context(toolset) + self._running_count += 1 return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None ) -> bool | None: - if self._exit_stack is not None: + self._running_count -= 1 + if self._running_count <= 0 and self._exit_stack is not None: await self._exit_stack.aclose() self._exit_stack = None return None @@ -88,9 +92,9 @@ async def call_tool( ) -> Any: return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) - def set_mcp_sampling_model(self, model: Model) -> None: + def _set_mcp_sampling_model(self, model: Model) -> None: for toolset in self.toolsets: - toolset.set_mcp_sampling_model(model) + toolset._set_mcp_sampling_model(model) def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: try: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index 1bc69628b..b6b1f8806 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -8,7 +8,7 @@ from .._run_context import AgentDepsT, RunContext from ..tools import ToolDefinition from . import AbstractToolset -from .run import RunToolset +from ._run import RunToolset class DeferredToolset(AbstractToolset[AgentDepsT]): diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py index 6a6789620..336b18a39 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py @@ -6,7 +6,7 @@ from .._run_context import AgentDepsT, RunContext from ..tools import ToolDefinition from . import AbstractToolset -from .individually_prepared import IndividuallyPreparedToolset +from ._individually_prepared import IndividuallyPreparedToolset @dataclass(init=False) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index 0f5261891..ce614ad0c 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -19,13 +19,13 @@ ToolPrepareFunc, ) from . import AbstractToolset -from .individually_prepared import IndividuallyPreparedToolset -from .run import RunToolset +from ._individually_prepared import IndividuallyPreparedToolset +from ._run import RunToolset @dataclass(init=False) class FunctionToolset(AbstractToolset[AgentDepsT]): - """A toolset that functions can be registered to as tools.""" + """A toolset that lets Python functions be used as tools.""" max_retries: int = field(default=1) tools: dict[str, Tool[Any]] = field(default_factory=dict) @@ -188,7 +188,7 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: tool_name = tool_def.name - ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) + ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries[tool_name]) return await self.tools[tool_name].prepare_tool_def(ctx) @property diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py index 857b923d1..9210746ae 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -7,7 +7,7 @@ from .._run_context import AgentDepsT, RunContext from ..tools import ToolDefinition -from .run import RunToolset +from ._run import RunToolset from .wrapper import WrapperToolset diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py index 2ba2b288a..f35d7154d 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py @@ -5,7 +5,7 @@ from .._run_context import AgentDepsT, RunContext from ..exceptions import UserError from ..tools import ToolsPrepareFunc -from .run import RunToolset +from ._run import RunToolset from .wrapper import WrapperToolset diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/processed.py b/pydantic_ai_slim/pydantic_ai/toolsets/processed.py index 084296377..c63854f7b 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/processed.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/processed.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Protocol from .._run_context import AgentDepsT, RunContext -from .run import RunToolset +from ._run import RunToolset from .wrapper import WrapperToolset @@ -29,7 +29,7 @@ def __call__(self, name: str, tool_args: dict[str, Any], *args: Any, **kwargs: A @dataclass class ProcessedToolset(WrapperToolset[AgentDepsT]): - """A toolset that lets the tool call arguments and return value be customized using a process function.""" + """A toolset that lets the tool call arguments and return value be customized using a wrapper function.""" process: ToolProcessFunc[AgentDepsT] diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 89d9de656..01e52928a 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from ..models import Model - from .run import RunToolset + from ._run import RunToolset @dataclass @@ -59,8 +59,8 @@ async def call_tool( ) -> Any: return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) - def set_mcp_sampling_model(self, model: Model) -> None: - self.wrapped.set_mcp_sampling_model(model) + def _set_mcp_sampling_model(self, model: Model) -> None: + self.wrapped._set_mcp_sampling_model(model) def __getattr__(self, item: str): return getattr(self.wrapped, item) # pragma: no cover diff --git a/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml new file mode 100644 index 000000000..e33e36f96 --- /dev/null +++ b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml @@ -0,0 +1,391 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2501' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1086' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '420' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + created: 1751491994 + id: chatcmpl-BozMoBhgfC5D8QBjkiOwz5OxxrwQK + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 18 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 268 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 286 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2748' + content-type: + - application/json + cookie: + - __cf_bm=JOV7WG2Y48FZrZxdh0IZvA9mCj_ljIN3DhGMuC1pw6M-1751491995-1.0.1.1-zGPrLbzYx7y3iZT28xogbHO1KAIej60kPEwQ8ZxGMxv1r.ICtqI0T8WCnlyUccKfLSXB6ZTNQT05xCma8LSvq2pk4X2eEuSkYC1sPqbuLU8; + _cfuvid=LdoyX0uKYwM98NSSSvySlZAiJHCVHz_1krUGKbWmNHg-1751491995391-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + - content: '32.0' + role: tool + tool_call_id: call_hS0oexgCNI6TneJuPPuwn9jQ + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '849' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '520' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: 0 degrees Celsius is 32.0 degrees Fahrenheit. + refusal: null + role: assistant + created: 1751491998 + id: chatcmpl-BozMsevK8quJblNOyNCaDQpdtDwI5 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 300 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 312 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_agent.py b/tests/test_agent.py index 0473477ea..52c913aa6 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2307,7 +2307,7 @@ def another_tool(y: int) -> int: tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ), RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsDatetime(), @@ -2394,7 +2394,7 @@ def another_tool(y: int) -> int: # pragma: no cover ), RetryPromptPart( tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), diff --git a/tests/test_examples.py b/tests/test_examples.py index cf9f98802..03ec1c342 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -6,6 +6,7 @@ import shutil import sys from collections.abc import AsyncIterator, Iterable, Sequence +from contextlib import nullcontext from dataclasses import dataclass from inspect import FrameInfo from io import StringIO @@ -40,7 +41,7 @@ from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets import AbstractToolset -from pydantic_ai.toolsets.run import RunToolset +from pydantic_ai.toolsets._run import RunToolset from .conftest import ClientWithHandler, TestEnv, try_import @@ -263,6 +264,7 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): is_running = True + override_sampling_model = nullcontext async def __aenter__(self) -> MockMCPServer: return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index a7ed001df..c25873f32 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -257,12 +257,10 @@ def get_none() -> None: # pragma: no cover await agent.run('No conflict') -async def test_agent_with_server_not_running(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - agent = Agent(model, toolsets=[server]) - with pytest.raises(UserError, match='MCP server is not running'): - await agent.run('What is 0 degrees Celsius in Fahrenheit?') +@pytest.mark.vcr() +async def test_agent_with_server_not_running(agent: Agent, allow_model_requests: None): + result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') + assert result.output == snapshot('0 degrees Celsius is 32.0 degrees Fahrenheit.') async def test_log_level_unset(run_context: RunContext[int]): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index ba81fd9fd..cc5a0c4b3 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -624,7 +624,7 @@ def another_tool(y: int) -> int: tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), @@ -737,7 +737,7 @@ def another_tool(y: int) -> int: # pragma: no cover part_kind='tool-return', ), RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool", + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), @@ -992,7 +992,7 @@ def known_tool(x: int) -> int: ), FunctionToolResultEvent( result=RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: known_tool", + content="Unknown tool name: 'unknown_tool'. Available tools: 'known_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), diff --git a/tests/test_tools.py b/tests/test_tools.py index fe582f717..d7f9840bc 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1155,7 +1155,7 @@ async def prepare_tool_defs( ctx: RunContext[None], tool_defs: list[ToolDefinition] ) -> Union[list[ToolDefinition], None]: nonlocal prepare_tools_retries - retry = ctx.retries.get('infinite_retry_tool', 0) + retry = ctx.retries['infinite_retry_tool'] prepare_tools_retries.append(retry) return tool_defs From acddb8d9cd58c2fa9832a21e8f86383c0a517f7f Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 2 Jul 2025 23:34:35 +0000 Subject: [PATCH 74/90] Add sampling_model to Agent __init__, iter, run (etc), and override, pass sampling_model to MCPServer through RunContext, and make Agent an async contextmanager instead of run_toolsets --- docs/mcp/client.md | 19 ++-- mcp-run-python/README.md | 2 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 + pydantic_ai_slim/pydantic_ai/_run_context.py | 2 + pydantic_ai_slim/pydantic_ai/agent.py | 101 ++++++++++++++---- pydantic_ai_slim/pydantic_ai/mcp.py | 57 +++++----- .../pydantic_ai/toolsets/__init__.py | 7 -- .../pydantic_ai/toolsets/combined.py | 13 +-- .../pydantic_ai/toolsets/wrapper.py | 8 -- tests/test_examples.py | 2 + tests/test_mcp.py | 28 ++--- tests/test_toolset.py | 10 +- 12 files changed, 155 insertions(+), 96 deletions(-) diff --git a/docs/mcp/client.md b/docs/mcp/client.md index a6efd847a..33bd4e196 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -29,7 +29,7 @@ Examples of both are shown below; [mcp-run-python](run-python.md) is used as the [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. !!! note - [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_toolsets()`][pydantic_ai.Agent.run_toolsets]. Running the server is not managed by PydanticAI. + [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI. The name "HTTP" is used since this implementation will be adapted in future to use the new [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. @@ -51,7 +51,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent.run_toolsets(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -92,9 +92,8 @@ Will display as follows: !!! note [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be - running and accepting HTTP connections before calling - [`agent.run_toolsets()`][pydantic_ai.Agent.run_toolsets]. Running the server is not - managed by PydanticAI. + running and accepting HTTP connections before running the agent. Running the server is not + managed by Pydantic AI. Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. @@ -121,7 +120,7 @@ server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent.run_toolsets(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -138,7 +137,7 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. !!! note - When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`agent.run_toolsets()`][pydantic_ai.Agent.run_toolsets] context manager is responsible for starting and stopping the server. + When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager is responsible for starting and stopping the server. ```python {title="mcp_stdio_client.py" py="3.10"} from pydantic_ai import Agent @@ -160,7 +159,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_toolsets(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -205,7 +204,7 @@ agent = Agent( async def main(): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Echo with deps set to 42', deps=42) print(result.output) #> {"echo_deps":{"echo":"This is an echo message","deps":42}} @@ -364,7 +363,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index 93bbfc87f..edd84ddb8 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -56,7 +56,7 @@ agent = Agent('claude-3-5-haiku-latest', toolsets=[server]) async def main(): - async with agent.run_toolsets(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025.w diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e3bcee4aa..a4ac696cf 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -115,6 +115,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): history_processors: Sequence[HistoryProcessor[DepsT]] toolset: RunToolset[DepsT] + sampling_model: models.Model tracer: Tracer instrumentation_settings: InstrumentationSettings | None = None @@ -561,6 +562,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT deps=ctx.deps.user_deps, model=ctx.deps.model, usage=ctx.state.usage, + sampling_model=ctx.deps.sampling_model, prompt=ctx.deps.prompt, messages=ctx.state.message_history, run_step=ctx.state.run_step, diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 46f7e664b..92d8ab32a 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -28,6 +28,8 @@ class RunContext(Generic[AgentDepsT]): """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" + sampling_model: Model + """The model used for MCP sampling.""" prompt: str | Sequence[_messages.UserContent] | None = None """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 8c71a74c7..475c7ab03 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -165,6 +165,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _sampling_model: models.Model | models.KnownModelName | str | None = dataclasses.field(repr=False) + + _running_count: int = dataclasses.field(repr=False) + _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) @overload def __init__( @@ -190,6 +194,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... @overload @@ -221,6 +226,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... @overload @@ -250,6 +256,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... def __init__( @@ -276,6 +283,7 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Any, ): """Create an agent. @@ -324,6 +332,7 @@ def __init__( history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. + sampling_model: The model to use for MCP sampling, if not provided, the agent's model will be used. """ if model is None or defer_model_check: self.model = model @@ -424,8 +433,16 @@ def __init__( self.history_processors = history_processors or [] + self._sampling_model = sampling_model + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) + self._override_sampling_model: ContextVar[_utils.Option[models.Model]] = ContextVar( + '_override_sampling_model', default=None + ) + + self._exit_stack = None + self._running_count = 0 @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: @@ -446,6 +463,7 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -462,6 +480,7 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -479,6 +498,7 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -494,6 +514,7 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -525,6 +546,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -550,6 +572,7 @@ async def main(): usage_limits=usage_limits, usage=usage, toolsets=toolsets, + sampling_model=sampling_model, ) as agent_run: async for _ in agent_run: pass @@ -571,6 +594,7 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -588,6 +612,7 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -606,6 +631,7 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager @@ -622,6 +648,7 @@ async def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -697,6 +724,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -706,6 +734,8 @@ async def main(): model_used = self._get_model(model) del model + sampling_model_used = self._get_sampling_model(sampling_model) or model_used + if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') @@ -752,6 +782,7 @@ async def main(): deps=deps, model=model_used, usage=usage, + sampling_model=sampling_model_used, prompt=user_prompt, messages=state.message_history, run_step=state.run_step, @@ -813,6 +844,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: output_validators=output_validators, history_processors=self.history_processors, toolset=run_toolset, + sampling_model=sampling_model_used, tracer=tracer, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, @@ -884,6 +916,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -900,6 +933,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -917,6 +951,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -932,6 +967,7 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -962,6 +998,7 @@ def run_sync( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -989,6 +1026,7 @@ def run_sync( usage=usage, infer_name=False, toolsets=toolsets, + sampling_model=sampling_model, ) ) @@ -1054,6 +1092,7 @@ async def run_stream( # noqa C901 usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -1082,6 +1121,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. + sampling_model: Optional model to use for MCP sampling. Returns: The result of the run. @@ -1113,6 +1153,7 @@ async def main(): usage=usage, infer_name=False, toolsets=toolsets, + sampling_model=sampling_model, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node @@ -1203,6 +1244,7 @@ def override( *, deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + sampling_model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies and model. @@ -1212,6 +1254,7 @@ def override( Args: deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. + sampling_model: The model to use for MCP sampling instead of the sampling model passed to the agent run. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -1223,6 +1266,11 @@ def override( else: model_token = None + if _utils.is_set(sampling_model): + sampling_model_token = self._override_sampling_model.set(_utils.Some(models.infer_model(sampling_model))) + else: + sampling_model_token = None + try: yield finally: @@ -1230,6 +1278,8 @@ def override( self._override_deps.reset(deps_token) if model_token is not None: self._override_model.reset(model_token) + if sampling_model_token is not None: + self._override_sampling_model.reset(sampling_model_token) @overload def instructions( @@ -1696,6 +1746,19 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps + def _get_sampling_model( + self, sampling_model: models.Model | models.KnownModelName | str | None + ) -> models.Model | None: + """Get the sampling model for a run.""" + if some_sampling_model := self._override_sampling_model.get(): + return some_sampling_model.value + elif sampling_model is not None: + return models.infer_model(sampling_model) + elif self._sampling_model is not None: + return models.infer_model(self._sampling_model) + else: + return None + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. @@ -1781,27 +1844,24 @@ def is_end_node( """ return isinstance(node, End) - @asynccontextmanager - async def run_toolsets( - self, sampling_model: models.Model | models.KnownModelName | str | None = None - ) -> AsyncIterator[None]: - """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] among toolsets so they can be used by the agent. + async def __aenter__(self) -> Self: + if self._running_count == 0: + self._exit_stack = AsyncExitStack() + await self._exit_stack.enter_async_context(self._toolset) + self._running_count += 1 + return self - Returns: a context manager to start and shutdown the servers. - """ - try: - model: models.Model | None = self._get_model(sampling_model) - except exceptions.UserError: # pragma: no cover - model = None - - async with AsyncExitStack() as exit_stack: - if model is not None: # pragma: no branch - exit_stack.enter_context(self._toolset.override_sampling_model(model)) - await exit_stack.enter_async_context(self._toolset) - yield + async def __aexit__(self, *args: Any) -> bool | None: + self._running_count -= 1 + if self._running_count <= 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + return None @asynccontextmanager - @deprecated('`run_mcp_servers` is deprecated, use `run_toolsets` instead.') + @deprecated( + '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set an MCP sampling model, use `with agent.override(sampling_model=...)`.' + ) async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: @@ -1809,8 +1869,9 @@ async def run_mcp_servers( Returns: a context manager to start and shutdown the servers. """ - async with self.run_toolsets(model): - yield + with self.override(sampling_model=model or _utils.UNSET): + async with self: + yield def to_a2a( self, diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 4e74e08db..c62137d65 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -4,7 +4,7 @@ import functools from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Iterator, Sequence -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager, nullcontext from contextvars import ContextVar from dataclasses import dataclass from pathlib import Path @@ -120,7 +120,7 @@ async def list_tools(self) -> list[mcp_types.Tool]: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ - async with self: + async with self: # Ensure server is running result = await self._client.list_tools() return result.tools @@ -149,24 +149,28 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ - async with self: - try: - # meta param is not provided by session yet, so build and can send_request directly. - result = await self._client.send_request( - mcp_types.ClientRequest( - mcp_types.CallToolRequest( - method='tools/call', - params=mcp_types.CallToolRequestParams( - name=name, - arguments=tool_args, - _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, - ), - ) - ), - mcp_types.CallToolResult, - ) - except McpError as e: - raise exceptions.ModelRetry(e.error.message) + sampling_contextmanager = ( + nullcontext() if self._get_sampling_model() else self.override_sampling_model(ctx.sampling_model) + ) + with sampling_contextmanager: + async with self: # Ensure server is running + try: + # meta param is not provided by session yet, so build and can send_request directly. + result = await self._client.send_request( + mcp_types.ClientRequest( + mcp_types.CallToolRequest( + method='tools/call', + params=mcp_types.CallToolRequestParams( + name=name, + arguments=tool_args, + _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + ), + ) + ), + mcp_types.CallToolResult, + ) + except McpError as e: + raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] @@ -241,11 +245,14 @@ async def __aexit__( if self._running_count <= 0: await self._exit_stack.aclose() + def _get_sampling_model(self) -> models.Model | None: + return self._override_sampling_model.get() or self.sampling_model + async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: """MCP sampling callback.""" - sampling_model = self._override_sampling_model.get() or self.sampling_model + sampling_model = self._get_sampling_model() if sampling_model is None: raise ValueError('Sampling model is not set') # pragma: no cover @@ -336,7 +343,7 @@ class MCPServerStdio(MCPServer): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_toolsets(): # (2)! + async with agent: # (2)! ... ``` @@ -574,7 +581,7 @@ class MCPServerSSE(_MCPServerHTTP): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_toolsets(): # (2)! + async with agent: # (2)! ... ``` @@ -608,7 +615,7 @@ class MCPServerHTTP(MCPServerSSE): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_toolsets(): # (2)! + async with agent: # (2)! ... ``` @@ -637,7 +644,7 @@ class MCPServerStreamableHTTP(_MCPServerHTTP): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_toolsets(): # (2)! + async with agent: # (2)! ... ``` """ diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py index b07e87541..66caa1678 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -1,8 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Iterator -from contextlib import contextmanager from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, Literal @@ -13,7 +11,6 @@ from ..tools import ToolDefinition if TYPE_CHECKING: - from ..models import Model from ._run import RunToolset @@ -81,7 +78,3 @@ async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any ) -> Any: raise NotImplementedError() - - @contextmanager - def override_sampling_model(self, model: Model) -> Iterator[None]: - yield diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index df6af0670..738f82a3c 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -1,8 +1,8 @@ from __future__ import annotations import asyncio -from collections.abc import Iterator, Sequence -from contextlib import AsyncExitStack, ExitStack, contextmanager +from collections.abc import Sequence +from contextlib import AsyncExitStack from dataclasses import dataclass from types import TracebackType from typing import TYPE_CHECKING, Any @@ -17,7 +17,7 @@ from ._run import RunToolset if TYPE_CHECKING: - from ..models import Model + pass @dataclass(init=False) @@ -92,13 +92,6 @@ async def call_tool( ) -> Any: return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) - @contextmanager - def override_sampling_model(self, model: Model) -> Iterator[None]: - with ExitStack() as exit_stack: - for toolset in self.toolsets: - exit_stack.enter_context(toolset.override_sampling_model(model)) - yield - def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: try: return self._toolset_per_tool_name[name] diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 856a68876..354de8ebc 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -1,8 +1,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Iterator -from contextlib import contextmanager from dataclasses import dataclass from types import TracebackType from typing import TYPE_CHECKING, Any @@ -15,7 +13,6 @@ from . import AbstractToolset if TYPE_CHECKING: - from ..models import Model from ._run import RunToolset @@ -61,10 +58,5 @@ async def call_tool( ) -> Any: return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) - @contextmanager - def override_sampling_model(self, model: Model) -> Iterator[None]: - with self.wrapped.override_sampling_model(model): - yield - def __getattr__(self, item: str): return getattr(self.wrapped, item) # pragma: no cover diff --git a/tests/test_examples.py b/tests/test_examples.py index 5f30d4b87..03ec1c342 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -6,6 +6,7 @@ import shutil import sys from collections.abc import AsyncIterator, Iterable, Sequence +from contextlib import nullcontext from dataclasses import dataclass from inspect import FrameInfo from io import StringIO @@ -263,6 +264,7 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): is_running = True + override_sampling_model = nullcontext async def __aenter__(self) -> MockMCPServer: return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index c25873f32..bb8181866 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -68,7 +68,7 @@ def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: @pytest.fixture def run_context(model: Model) -> RunContext[int]: - return RunContext(deps=0, model=model, usage=Usage()) + return RunContext(deps=0, model=model, usage=Usage(), sampling_model=model) async def test_stdio_server(run_context: RunContext[int]): @@ -151,7 +151,7 @@ def test_sse_server_with_header_and_timeout(): @pytest.mark.vcr() async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') assert result.output == snapshot('0 degrees Celsius is equal to 32 degrees Fahrenheit.') assert result.all_messages() == snapshot( @@ -228,7 +228,7 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_toolsets(): + async with agent: with pytest.raises( UserError, match=re.escape( @@ -251,7 +251,7 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_toolsets(): + async with agent: # This means that we passed the _prepare_request_parameters check and there is no conflict in the tool name with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): await agent.run('No conflict') @@ -285,7 +285,7 @@ async def test_log_level_set(run_context: RunContext[int]): @pytest.mark.vcr() async def test_tool_returning_str(allow_model_requests: None, agent: Agent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('What is the weather in Mexico City?') assert result.output == snapshot( 'The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' @@ -364,7 +364,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Get me the product name') assert result.output == snapshot('The product name is "PydanticAI".') assert result.all_messages() == snapshot( @@ -437,7 +437,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A @pytest.mark.vcr() async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Get me the image resource') assert result.output == snapshot( 'This is an image of a sliced kiwi with a vibrant green interior and black seeds.' @@ -520,7 +520,7 @@ async def test_tool_returning_audio_resource( allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str ): model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key)) - async with agent.run_toolsets(): + async with agent: result = await agent.run("What's the content of the audio resource?", model=model) assert result.output == snapshot('The audio resource contains a voice saying "Hello, my name is Marcelo."') assert result.all_messages() == snapshot( @@ -571,7 +571,7 @@ async def test_tool_returning_audio_resource( @pytest.mark.vcr() async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Get me an image') assert result.output == snapshot('Here is an image of a sliced kiwi on a white background.') assert result.all_messages() == snapshot( @@ -651,7 +651,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im @pytest.mark.vcr() async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Get me a dict, respond on one line') assert result.output == snapshot('{"foo":"bar","baz":123}') assert result.all_messages() == snapshot( @@ -718,7 +718,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_error(allow_model_requests: None, agent: Agent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Get me an error, pass False as a value, unless the tool tells you otherwise') assert result.output == snapshot( 'I called the tool with the correct parameter, and it returned: "This is not an error."' @@ -832,7 +832,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_none(allow_model_requests: None, agent: Agent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Call the none tool and say Hello') assert result.output == snapshot('Hello! How can I assist you today?') assert result.all_messages() == snapshot( @@ -899,7 +899,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_toolsets(): + async with agent: result = await agent.run('Get me multiple items and summarize in one sentence') assert result.output == snapshot( 'The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.' @@ -1017,7 +1017,7 @@ async def test_mcp_server_raises_mcp_error( ) -> None: mcp_error = McpError(error=ErrorData(code=400, message='Test MCP error conversion')) - async with agent.run_toolsets(): + async with agent: with patch.object( mcp_server._client, # pyright: ignore[reportPrivateUsage] 'send_request', diff --git a/tests/test_toolset.py b/tests/test_toolset.py index aea2c13c0..4122df05b 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -18,7 +18,15 @@ def build_run_context(deps: T) -> RunContext[T]: - return RunContext(deps=deps, model=TestModel(), usage=Usage(), prompt=None, messages=[], run_step=0) + return RunContext( + deps=deps, + model=TestModel(), + usage=Usage(), + sampling_model=TestModel(), + prompt=None, + messages=[], + run_step=0, + ) async def test_function_toolset_prepare_for_run(): From 89fc266e84266e335f37993b9b327e93bb8af6eb Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 2 Jul 2025 23:37:51 +0000 Subject: [PATCH 75/90] Turn RunContext.retries from a defaultdict into a dict again as the 0 being stored on read broke a test --- pydantic_ai_slim/pydantic_ai/_run_context.py | 3 +-- pydantic_ai_slim/pydantic_ai/toolsets/_run.py | 7 +++---- pydantic_ai_slim/pydantic_ai/toolsets/function.py | 2 +- tests/test_tools.py | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 92d8ab32a..2eb50742f 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -1,7 +1,6 @@ from __future__ import annotations as _annotations import dataclasses -from collections import defaultdict from collections.abc import Sequence from dataclasses import field from typing import TYPE_CHECKING, Generic @@ -34,7 +33,7 @@ class RunContext(Generic[AgentDepsT]): """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) """Messages exchanged in the conversation so far.""" - retries: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int)) + retries: dict[str, int] = field(default_factory=dict) """Number of retries for each tool so far.""" tool_call_id: str | None = None """The ID of the tool call.""" diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_run.py b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py index f04068666..411ee35c9 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/_run.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections import defaultdict from collections.abc import Iterable, Iterator from contextlib import contextmanager from dataclasses import dataclass, replace @@ -25,7 +24,7 @@ class RunToolset(WrapperToolset[AgentDepsT]): ctx: RunContext[AgentDepsT] _tool_defs: list[ToolDefinition] _tool_names: list[str] - _retries: defaultdict[str, int] + _retries: dict[str, int] _original: AbstractToolset[AgentDepsT] def __init__( @@ -108,14 +107,14 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon msg = 'No tools available.' raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') - ctx = replace(ctx, tool_name=name, retry=self._retries[name], retries={}) + ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) yield ctx except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: try: max_retries = self._max_retries_for_tool(name) except Exception: max_retries = 1 - current_retry = self._retries[name] + current_retry = self._retries.get(name, 0) if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None: e = e.__cause__ diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index ce614ad0c..fbc60f8b0 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -188,7 +188,7 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: tool_name = tool_def.name - ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries[tool_name]) + ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) return await self.tools[tool_name].prepare_tool_def(ctx) @property diff --git a/tests/test_tools.py b/tests/test_tools.py index d7f9840bc..fe582f717 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1155,7 +1155,7 @@ async def prepare_tool_defs( ctx: RunContext[None], tool_defs: list[ToolDefinition] ) -> Union[list[ToolDefinition], None]: nonlocal prepare_tools_retries - retry = ctx.retries['infinite_retry_tool'] + retry = ctx.retries.get('infinite_retry_tool', 0) prepare_tools_retries.append(retry) return tool_defs From 7e3331b3ec9db83202dfd2ae2170de83e9e01f73 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 2 Jul 2025 23:39:51 +0000 Subject: [PATCH 76/90] Remove unnecessary if TYPE_CHECKING --- pydantic_ai_slim/pydantic_ai/toolsets/combined.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 738f82a3c..8930c36c2 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -5,7 +5,7 @@ from contextlib import AsyncExitStack from dataclasses import dataclass from types import TracebackType -from typing import TYPE_CHECKING, Any +from typing import Any from pydantic_core import SchemaValidator from typing_extensions import Self @@ -16,9 +16,6 @@ from . import AbstractToolset from ._run import RunToolset -if TYPE_CHECKING: - pass - @dataclass(init=False) class CombinedToolset(AbstractToolset[AgentDepsT]): From ebf6f4057273ee72b7641117e1c43c1bdb41eb5b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 3 Jul 2025 18:18:15 +0000 Subject: [PATCH 77/90] Remove Agent sampling_model field (and method argument) in favor of Agent.set_mcp_sampling_model --- docs/mcp/client.md | 1 + pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 - pydantic_ai_slim/pydantic_ai/_run_context.py | 2 - pydantic_ai_slim/pydantic_ai/agent.py | 96 ++++++------------- pydantic_ai_slim/pydantic_ai/mcp.py | 70 +++++--------- .../pydantic_ai/toolsets/__init__.py | 5 +- .../pydantic_ai/toolsets/combined.py | 6 +- .../pydantic_ai/toolsets/wrapper.py | 5 +- tests/test_examples.py | 2 - tests/test_mcp.py | 2 +- tests/test_toolset.py | 1 - 11 files changed, 68 insertions(+), 124 deletions(-) diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 33bd4e196..2c611b927 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -364,6 +364,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): async with agent: + agent.set_mcp_sampling_model() result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index a4ac696cf..e3bcee4aa 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -115,7 +115,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): history_processors: Sequence[HistoryProcessor[DepsT]] toolset: RunToolset[DepsT] - sampling_model: models.Model tracer: Tracer instrumentation_settings: InstrumentationSettings | None = None @@ -562,7 +561,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT deps=ctx.deps.user_deps, model=ctx.deps.model, usage=ctx.state.usage, - sampling_model=ctx.deps.sampling_model, prompt=ctx.deps.prompt, messages=ctx.state.message_history, run_step=ctx.state.run_step, diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index 2eb50742f..528aa9100 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -27,8 +27,6 @@ class RunContext(Generic[AgentDepsT]): """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" - sampling_model: Model - """The model used for MCP sampling.""" prompt: str | Sequence[_messages.UserContent] | None = None """The original user prompt passed to the run.""" messages: list[_messages.ModelMessage] = field(default_factory=list) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 475c7ab03..85b787cda 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -165,7 +165,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) - _sampling_model: models.Model | models.KnownModelName | str | None = dataclasses.field(repr=False) _running_count: int = dataclasses.field(repr=False) _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) @@ -194,7 +193,6 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... @overload @@ -226,7 +224,6 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... @overload @@ -256,7 +253,6 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> None: ... def __init__( @@ -283,7 +279,6 @@ def __init__( end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Any, ): """Create an agent. @@ -332,7 +327,6 @@ def __init__( history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. - sampling_model: The model to use for MCP sampling, if not provided, the agent's model will be used. """ if model is None or defer_model_check: self.model = model @@ -433,13 +427,8 @@ def __init__( self.history_processors = history_processors or [] - self._sampling_model = sampling_model - self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) - self._override_sampling_model: ContextVar[_utils.Option[models.Model]] = ContextVar( - '_override_sampling_model', default=None - ) self._exit_stack = None self._running_count = 0 @@ -463,7 +452,6 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -480,7 +468,6 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -498,7 +485,6 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -514,7 +500,6 @@ async def run( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -545,8 +530,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. - sampling_model: Optional model to use for MCP sampling. + toolsets: Optional toolsets to use for this run instead Returns: The result of the run. @@ -572,7 +556,6 @@ async def main(): usage_limits=usage_limits, usage=usage, toolsets=toolsets, - sampling_model=sampling_model, ) as agent_run: async for _ in agent_run: pass @@ -594,7 +577,6 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -612,7 +594,6 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -631,7 +612,6 @@ def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager @@ -648,7 +628,6 @@ async def iter( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -723,8 +702,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. - sampling_model: Optional model to use for MCP sampling. + toolsets: Optional toolsets to use for this run instead Returns: The result of the run. @@ -734,8 +712,6 @@ async def main(): model_used = self._get_model(model) del model - sampling_model_used = self._get_sampling_model(sampling_model) or model_used - if 'result_type' in _deprecated_kwargs: # pragma: no cover if output_type is not str: raise TypeError('`result_type` and `output_type` cannot be set at the same time.') @@ -782,7 +758,6 @@ async def main(): deps=deps, model=model_used, usage=usage, - sampling_model=sampling_model_used, prompt=user_prompt, messages=state.message_history, run_step=state.run_step, @@ -844,7 +819,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: output_validators=output_validators, history_processors=self.history_processors, toolset=run_toolset, - sampling_model=sampling_model_used, tracer=tracer, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, @@ -916,7 +890,6 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -933,7 +906,6 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -951,7 +923,6 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -967,7 +938,6 @@ def run_sync( usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -997,8 +967,7 @@ def run_sync( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. - sampling_model: Optional model to use for MCP sampling. + toolsets: Optional toolsets to use for this run instead Returns: The result of the run. @@ -1026,7 +995,6 @@ def run_sync( usage=usage, infer_name=False, toolsets=toolsets, - sampling_model=sampling_model, ) ) @@ -1092,7 +1060,6 @@ async def run_stream( # noqa C901 usage: _usage.Usage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - sampling_model: models.Model | models.KnownModelName | str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -1120,8 +1087,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent. - sampling_model: Optional model to use for MCP sampling. + toolsets: Optional toolsets to use for this run instead Returns: The result of the run. @@ -1153,7 +1119,6 @@ async def main(): usage=usage, infer_name=False, toolsets=toolsets, - sampling_model=sampling_model, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node @@ -1244,7 +1209,6 @@ def override( *, deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, - sampling_model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent dependencies and model. @@ -1254,7 +1218,6 @@ def override( Args: deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. - sampling_model: The model to use for MCP sampling instead of the sampling model passed to the agent run. """ if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -1266,11 +1229,6 @@ def override( else: model_token = None - if _utils.is_set(sampling_model): - sampling_model_token = self._override_sampling_model.set(_utils.Some(models.infer_model(sampling_model))) - else: - sampling_model_token = None - try: yield finally: @@ -1278,8 +1236,6 @@ def override( self._override_deps.reset(deps_token) if model_token is not None: self._override_model.reset(model_token) - if sampling_model_token is not None: - self._override_sampling_model.reset(sampling_model_token) @overload def instructions( @@ -1746,19 +1702,6 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps - def _get_sampling_model( - self, sampling_model: models.Model | models.KnownModelName | str | None - ) -> models.Model | None: - """Get the sampling model for a run.""" - if some_sampling_model := self._override_sampling_model.get(): - return some_sampling_model.value - elif sampling_model is not None: - return models.infer_model(sampling_model) - elif self._sampling_model is not None: - return models.infer_model(self._sampling_model) - else: - return None - def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. @@ -1858,9 +1801,27 @@ async def __aexit__(self, *args: Any) -> bool | None: self._exit_stack = None return None + def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None: + """Set the sampling model on all MCP servers registered with the agent. + + If no sampling model is provided, the agent's model will be used. + """ + try: + sampling_model = models.infer_model(model) if model else self._get_model(None) + except exceptions.UserError as e: + raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e + + from .mcp import MCPServer + + def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None: + if isinstance(toolset, MCPServer): + toolset.sampling_model = sampling_model + + self._toolset.accept(_set_sampling_model) + @asynccontextmanager @deprecated( - '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set an MCP sampling model, use `with agent.override(sampling_model=...)`.' + '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model(...)`.' ) async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None @@ -1869,9 +1830,14 @@ async def run_mcp_servers( Returns: a context manager to start and shutdown the servers. """ - with self.override(sampling_model=model or _utils.UNSET): - async with self: - yield + try: + self.set_mcp_sampling_model(model) + except exceptions.UserError: + if model is not None: + raise + + async with self: + yield def to_a2a( self, diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index c62137d65..73ceffc2b 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -3,9 +3,8 @@ import base64 import functools from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Iterator, Sequence -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager, nullcontext -from contextvars import ContextVar +from collections.abc import AsyncIterator, Sequence +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from dataclasses import dataclass from pathlib import Path from types import TracebackType @@ -70,22 +69,6 @@ class MCPServer(AbstractToolset[Any], ABC): _exit_stack: AsyncExitStack sampling_model: models.Model | None = None - def __post_init__(self): - self._override_sampling_model: ContextVar[models.Model | None] = ContextVar( - '_override_sampling_model', default=None - ) - - @contextmanager - def override_sampling_model( - self, - model: models.Model, - ) -> Iterator[None]: - token = self._override_sampling_model.set(model) - try: - yield - finally: - self._override_sampling_model.reset(token) - @abstractmethod @asynccontextmanager async def client_streams( @@ -149,28 +132,23 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ - sampling_contextmanager = ( - nullcontext() if self._get_sampling_model() else self.override_sampling_model(ctx.sampling_model) - ) - with sampling_contextmanager: - async with self: # Ensure server is running - try: - # meta param is not provided by session yet, so build and can send_request directly. - result = await self._client.send_request( - mcp_types.ClientRequest( - mcp_types.CallToolRequest( - method='tools/call', - params=mcp_types.CallToolRequestParams( - name=name, - arguments=tool_args, - _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, - ), - ) - ), - mcp_types.CallToolResult, - ) - except McpError as e: - raise exceptions.ModelRetry(e.error.message) + async with self: # Ensure server is running + try: + result = await self._client.send_request( + mcp_types.ClientRequest( + mcp_types.CallToolRequest( + method='tools/call', + params=mcp_types.CallToolRequestParams( + name=name, + arguments=tool_args, + _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + ), + ) + ), + mcp_types.CallToolResult, + ) + except McpError as e: + raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] @@ -245,15 +223,11 @@ async def __aexit__( if self._running_count <= 0: await self._exit_stack.aclose() - def _get_sampling_model(self) -> models.Model | None: - return self._override_sampling_model.get() or self.sampling_model - async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: """MCP sampling callback.""" - sampling_model = self._get_sampling_model() - if sampling_model is None: + if self.sampling_model is None: raise ValueError('Sampling model is not set') # pragma: no cover pai_messages = _mcp.map_from_mcp_params(params) @@ -265,7 +239,7 @@ async def _sampling_callback( if stop_sequences := params.stopSequences: # pragma: no branch model_settings['stop_sequences'] = stop_sequences - model_response = await sampling_model.request( + model_response = await self.sampling_model.request( pai_messages, model_settings, models.ModelRequestParameters(), @@ -273,7 +247,7 @@ async def _sampling_callback( return mcp_types.CreateMessageResult( role='assistant', content=_mcp.map_from_model_response(model_response), - model=sampling_model.model_name, + model=self.sampling_model.model_name, ) def _map_tool_result_part( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py index 66caa1678..65293ca17 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from types import TracebackType -from typing import TYPE_CHECKING, Any, Generic, Literal +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal from pydantic_core import SchemaValidator from typing_extensions import Self @@ -78,3 +78,6 @@ async def call_tool( self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any ) -> Any: raise NotImplementedError() + + def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + return visitor(self) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 8930c36c2..ea706d0d4 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -5,7 +5,7 @@ from contextlib import AsyncExitStack from dataclasses import dataclass from types import TracebackType -from typing import Any +from typing import Any, Callable from pydantic_core import SchemaValidator from typing_extensions import Self @@ -89,6 +89,10 @@ async def call_tool( ) -> Any: return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) + def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + for toolset in self.toolsets: + toolset.accept(visitor) + def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]: try: return self._toolset_per_tool_name[name] diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 354de8ebc..71fcca012 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from types import TracebackType -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable from pydantic_core import SchemaValidator from typing_extensions import Self @@ -58,5 +58,8 @@ async def call_tool( ) -> Any: return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) + def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + return self.wrapped.accept(visitor) + def __getattr__(self, item: str): return getattr(self.wrapped, item) # pragma: no cover diff --git a/tests/test_examples.py b/tests/test_examples.py index 03ec1c342..5f30d4b87 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -6,7 +6,6 @@ import shutil import sys from collections.abc import AsyncIterator, Iterable, Sequence -from contextlib import nullcontext from dataclasses import dataclass from inspect import FrameInfo from io import StringIO @@ -264,7 +263,6 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): is_running = True - override_sampling_model = nullcontext async def __aenter__(self) -> MockMCPServer: return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index bb8181866..2394fdd13 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -68,7 +68,7 @@ def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: @pytest.fixture def run_context(model: Model) -> RunContext[int]: - return RunContext(deps=0, model=model, usage=Usage(), sampling_model=model) + return RunContext(deps=0, model=model, usage=Usage()) async def test_stdio_server(run_context: RunContext[int]): diff --git a/tests/test_toolset.py b/tests/test_toolset.py index 4122df05b..02fff0e89 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -22,7 +22,6 @@ def build_run_context(deps: T) -> RunContext[T]: deps=deps, model=TestModel(), usage=Usage(), - sampling_model=TestModel(), prompt=None, messages=[], run_step=0, From f7db040033781438f3e282a3b8fc684e8c0c6775 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 3 Jul 2025 23:43:30 +0000 Subject: [PATCH 78/90] Allow OutputSpec to be nested --- pydantic_ai_slim/pydantic_ai/_output.py | 160 +++++++++++++----------- pydantic_ai_slim/pydantic_ai/output.py | 14 ++- tests/typed_agent.py | 16 ++- 3 files changed, 110 insertions(+), 80 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ab505e3a4..4153ca30a 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Self, Union, cast, overload from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator @@ -26,6 +26,7 @@ TextOutput, TextOutputFunc, ToolOutput, + _OutputSpecItem, # type: ignore[reportPrivateUsage] ) from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition from .toolsets import AbstractToolset @@ -233,7 +234,7 @@ def build( # noqa: C901 else: other_outputs.append(output) - toolset = cls._build_toolset(tool_outputs + other_outputs, name=name, description=description, strict=strict) + toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict) if len(text_outputs) > 0: if len(text_outputs) > 1: @@ -268,73 +269,6 @@ def build( # noqa: C901 raise UserError('At least one output type must be provided.') - @staticmethod - def _build_toolset( - outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], - name: str | None = None, - description: str | None = None, - strict: bool | None = None, - ) -> OutputToolset[Any] | None: - if len(outputs) == 0: - return None - - processors: dict[str, ObjectOutputProcessor[Any]] = {} - tool_defs: list[ToolDefinition] = [] - - default_name = name or DEFAULT_OUTPUT_TOOL_NAME - default_description = description - default_strict = strict - - multiple = len(outputs) > 1 - for output in outputs: - name = None - description = None - strict = None - if isinstance(output, ToolOutput): - # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads - name = output.name - description = output.description - strict = output.strict - - output = output.output - - if name is None: - name = default_name - if multiple: - name += f'_{output.__name__}' - - i = 1 - original_name = name - while name in processors: - i += 1 - name = f'{original_name}_{i}' - - description = description or default_description - if strict is None: - strict = default_strict - - processor = ObjectOutputProcessor(output=output, description=description, strict=strict) - object_def = processor.object_def - - description = object_def.description - if not description: - description = DEFAULT_OUTPUT_TOOL_DESCRIPTION - if multiple: - description = f'{object_def.name}: {description}' - - tool_def = ToolDefinition( - name=name, - description=description, - parameters_json_schema=object_def.json_schema, - strict=object_def.strict, - outer_typed_dict_key=processor.outer_typed_dict_key, - kind='output', - ) - processors[name] = processor - tool_defs.append(tool_def) - - return OutputToolset(processors=processors, tool_defs=tool_defs) - @staticmethod def _build_processor( outputs: Sequence[OutputTypeOrFunction[OutputDataT]], @@ -908,6 +842,74 @@ class OutputToolset(AbstractToolset[AgentDepsT]): max_retries: int = field(default=1) output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list) + @classmethod + def build( + cls, + outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> Self | None: + if len(outputs) == 0: + return None + + processors: dict[str, ObjectOutputProcessor[Any]] = {} + tool_defs: list[ToolDefinition] = [] + + default_name = name or DEFAULT_OUTPUT_TOOL_NAME + default_description = description + default_strict = strict + + multiple = len(outputs) > 1 + for output in outputs: + name = None + description = None + strict = None + if isinstance(output, ToolOutput): + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads + name = output.name + description = output.description + strict = output.strict + + output = output.output + + if name is None: + name = default_name + if multiple: + name += f'_{output.__name__}' + + i = 1 + original_name = name + while name in processors: + i += 1 + name = f'{original_name}_{i}' + + description = description or default_description + if strict is None: + strict = default_strict + + processor = ObjectOutputProcessor(output=output, description=description, strict=strict) + object_def = processor.object_def + + description = object_def.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION + if multiple: + description = f'{object_def.name}: {description}' + + tool_def = ToolDefinition( + name=name, + description=description, + parameters_json_schema=object_def.json_schema, + strict=object_def.strict, + outer_typed_dict_key=processor.outer_typed_dict_key, + kind='output', + ) + processors[name] = processor + tool_defs.append(tool_def) + + return cls(processors=processors, tool_defs=tool_defs) + def __init__( self, tool_defs: list[ToolDefinition], @@ -942,17 +944,29 @@ async def call_tool( return output -def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: - outputs: Sequence[T] +@overload +def _flatten_output_spec( + output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]], +) -> Sequence[OutputTypeOrFunction[T]]: ... + + +@overload +def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ... + + +def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: + outputs: Sequence[OutputSpec[T]] if isinstance(output_spec, Sequence): outputs = output_spec else: outputs = (output_spec,) - outputs_flat: list[T] = [] + outputs_flat: list[_OutputSpecItem[T]] = [] for output in outputs: + if isinstance(output, Sequence): + outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output))) if union_types := _utils.get_union_args(output): outputs_flat.extend(union_types) else: - outputs_flat.append(output) + outputs_flat.append(cast(_OutputSpecItem[T], output)) return outputs_flat diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 316921781..7ddd889d1 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -267,15 +267,17 @@ def split_into_words(text: str) -> list[str]: """The function that will be called to process the model's plain text output. The function must take a single string argument.""" +_OutputSpecItem = TypeAliasType( + '_OutputSpecItem', + Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]], + type_params=(T_co,), +) + OutputSpec = TypeAliasType( 'OutputSpec', Union[ - OutputTypeOrFunction[T_co], - ToolOutput[T_co], - NativeOutput[T_co], - PromptedOutput[T_co], - TextOutput[T_co], - Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], + _OutputSpecItem[T_co], + Sequence['OutputSpec[T_co]'], ], type_params=(T_co,), ) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 941bbf987..9c15d53cd 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -10,7 +10,7 @@ from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.agent import AgentRunResult -from pydantic_ai.output import TextOutput, ToolOutput +from pydantic_ai.output import DeferredToolCalls, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition # Define here so we can check `if MYPY` below. This will not be executed, MYPY will always set it to True @@ -212,6 +212,14 @@ def my_method(self) -> bool: assert_type( complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) + + complex_deferred_output_agent = Agent[ + None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls + ](output_type=[complex_output_agent.output_type, DeferredToolCalls]) + assert_type( + complex_deferred_output_agent, + Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls], + ) else: # pyright is able to correctly infer the type here async_int_function_agent = Agent(output_type=foobar_plain) @@ -231,6 +239,12 @@ def my_method(self) -> bool: complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) + complex_deferred_output_agent = Agent(output_type=[complex_output_agent.output_type, DeferredToolCalls]) + assert_type( + complex_deferred_output_agent, + Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls], + ) + Tool(foobar_ctx, takes_ctx=True) Tool(foobar_ctx) From fe071494d80f5e25c7fd4651b14a3ec6c6aa5afe Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 3 Jul 2025 23:52:47 +0000 Subject: [PATCH 79/90] Document Agent.__aenter__ --- pydantic_ai_slim/pydantic_ai/agent.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 85b787cda..bd61f63e0 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1788,6 +1788,7 @@ def is_end_node( return isinstance(node, End) async def __aenter__(self) -> Self: + """Enter the agent. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered with the agent so they can be used in a run.""" if self._running_count == 0: self._exit_stack = AsyncExitStack() await self._exit_stack.enter_async_context(self._toolset) @@ -1828,6 +1829,9 @@ async def run_mcp_servers( ) -> AsyncIterator[None]: """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. + Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead. + If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model(...)`][pydantic_ai.agent.Agent.set_mcp_sampling_model]. + Returns: a context manager to start and shutdown the servers. """ try: From a0f4678d437db66fb87808d4d053ea9c41d2cf00 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 3 Jul 2025 23:59:55 +0000 Subject: [PATCH 80/90] Import Self from typing_extensions instead of typing --- pydantic_ai_slim/pydantic_ai/_output.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 4153ca30a..a0d8a0596 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -5,11 +5,11 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field, replace -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Self, Union, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator -from typing_extensions import TypedDict, TypeVar, assert_never +from typing_extensions import Self, TypedDict, TypeVar, assert_never from . import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext From db82d00f97ac156b217ea7a3a95a05d712138945 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 15:05:16 +0000 Subject: [PATCH 81/90] Actually use Agent.prepare_output_tools --- pydantic_ai_slim/pydantic_ai/agent.py | 6 +++--- tests/test_streaming.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index bd61f63e0..211c51705 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -733,12 +733,12 @@ async def main(): output_toolset = self._output_toolset if output_schema != self._output_schema or output_validators: - output_toolset = output_schema.toolset + output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) if output_toolset: output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators - if self._prepare_output_tools: - output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + if output_toolset and self._prepare_output_tools: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( diff --git a/tests/test_streaming.py b/tests/test_streaming.py index cc5a0c4b3..b4980ecea 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1109,7 +1109,7 @@ async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolD @agent.tool_plain(prepare=prepare_tool) def my_tool(x: int) -> int: - return x + 1 + return x + 1 # pragma: no cover async with agent.run_stream('Hello') as result: assert not result.is_complete @@ -1143,7 +1143,7 @@ async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolD @agent.tool_plain(prepare=prepare_tool) def my_tool(x: int) -> int: - return x + 1 + return x + 1 # pragma: no cover outputs: list[str | DeferredToolCalls] = [] events: list[Any] = [] From dea80506f68ee62245ac5a336a5f4b3ce8a47eae Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 15:19:27 +0000 Subject: [PATCH 82/90] Update test to account for fact that text output with early end_strategy stops later tools from being called --- tests/models/test_gemini.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 0008c61bc..43e10cb99 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -965,12 +965,47 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): @agent.tool_plain() def get_location(loc_name: str) -> str: - return f'Location for {loc_name}' + return f'Location for {loc_name}' # pragma: no cover async with agent.run_stream('Hello') as result: data = await result.get_output() assert data == 'Hello foo' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content='Hello foo'), + ToolCallPart( + tool_name='get_location', + args={'loc_name': 'San Fransisco'}, + tool_call_id=IsStr(), + ), + ], + usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + model_name='gemini-1.5-flash', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='Tool not executed - a final result was already processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) async def test_empty_text_ignored(): From 131a325f4184fba5774949874f499cf20326bedd Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 16:54:30 +0000 Subject: [PATCH 83/90] Improve test coverage --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 16 +--- pydantic_ai_slim/pydantic_ai/_output.py | 29 +++--- pydantic_ai_slim/pydantic_ai/agent.py | 7 +- pydantic_ai_slim/pydantic_ai/mcp.py | 13 +-- pydantic_ai_slim/pydantic_ai/result.py | 12 +-- pydantic_ai_slim/pydantic_ai/toolsets/_run.py | 31 ++++--- tests/test_agent.py | 47 ++++++++++ tests/test_examples.py | 2 - tests/test_logfire.py | 93 +++++++++++++------ tests/test_tools.py | 45 +++++++++ 10 files changed, 204 insertions(+), 91 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e3bcee4aa..999f1a5f5 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -538,8 +538,8 @@ async def _handle_text_response( text = '\n\n'.join(texts) try: + run_context = build_run_context(ctx) if isinstance(output_schema, _output.TextOutputSchema): - run_context = build_run_context(ctx) result_data = await output_schema.process(text, run_context) else: m = _messages.RetryPromptPart( @@ -547,7 +547,8 @@ async def _handle_text_response( ) raise ToolRetryError(m) - result_data = await _validate_output(result_data, ctx, None) + for validator in ctx.deps.output_validators: + result_data = await validator.validate(result_data, run_context) except ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) @@ -834,17 +835,6 @@ async def _call_tool( return await toolset.call_tool(run_context, tool_call.tool_name, args_dict) -async def _validate_output( - result_data: T, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], - tool_call: _messages.ToolCallPart | None, -) -> T: - for validator in ctx.deps.output_validators: - run_context = build_run_context(ctx) - result_data = await validator.validate(result_data, tool_call, run_context) - return result_data - - @dataclasses.dataclass class _RunMessages: messages: list[_messages.ModelMessage] diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index a0d8a0596..b212750ac 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -4,7 +4,7 @@ import json from abc import ABC, abstractmethod from collections.abc import Awaitable, Sequence -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload from pydantic import TypeAdapter, ValidationError @@ -83,7 +83,6 @@ def __post_init__(self): async def validate( self, result: T, - tool_call: _messages.ToolCallPart | None, run_context: RunContext[AgentDepsT], wrap_validation_errors: bool = True, ) -> T: @@ -91,7 +90,6 @@ async def validate( Args: result: The result data after Pydantic validation the message content. - tool_call: The original tool call message, `None` if there was no tool call. run_context: The current run context. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -99,12 +97,7 @@ async def validate( Result of either the validated result data (ok) or a retry message (Err). """ if self._takes_ctx: - ctx = ( - replace(run_context, tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id) - if tool_call - else run_context - ) - args = ctx, result + args = run_context, result else: args = (result,) @@ -117,10 +110,12 @@ async def validate( result_data = await _utils.run_in_executor(function, *args) except ModelRetry as r: if wrap_validation_errors: - m = _messages.RetryPromptPart(content=r.message) - if tool_call is not None: - m.tool_name = tool_call.tool_name - m.tool_call_id = tool_call.tool_call_id + m = _messages.RetryPromptPart( + content=r.message, + tool_name=run_context.tool_name, + ) + if run_context.tool_call_id: + m.tool_call_id = run_context.tool_call_id raise ToolRetryError(m) from r else: raise r @@ -190,7 +185,7 @@ def build( # noqa: C901 if output := next((output for output in outputs if isinstance(output, NativeOutput)), None): if len(outputs) > 1: - raise UserError('`NativeOutput` must be the only output type.') + raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover return NativeOutputSchema( processor=cls._build_processor( @@ -203,7 +198,7 @@ def build( # noqa: C901 ) elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None): if len(outputs) > 1: - raise UserError('`PromptedOutput` must be the only output type.') + raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover return PromptedOutputSchema( processor=cls._build_processor( @@ -940,7 +935,7 @@ async def call_tool( ) -> Any: output = await self.processors[name].call(tool_args, ctx) for validator in self.output_validators: - output = await validator.validate(output, None, ctx, wrap_validation_errors=False) + output = await validator.validate(output, ctx, wrap_validation_errors=False) return output @@ -965,7 +960,7 @@ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem for output in outputs: if isinstance(output, Sequence): outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output))) - if union_types := _utils.get_union_args(output): + elif union_types := _utils.get_union_args(output): outputs_flat.extend(union_types) else: outputs_flat.append(cast(_OutputSpecItem[T], output)) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 211c51705..f2e351bd9 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -370,11 +370,10 @@ def __init__( output_retries = result_retries if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): + if toolsets is not None: # pragma: no cover + raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) - if toolsets is None: - toolsets = mcp_servers - else: - toolsets = [*toolsets, *mcp_servers] + toolsets = mcp_servers _utils.validate_empty_kwargs(_deprecated_kwargs) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 73ceffc2b..189fd9d72 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -59,6 +59,7 @@ class MCPServer(AbstractToolset[Any], ABC): process_tool_call: ToolProcessFunc[Any] | None = None allow_sampling: bool = True max_retries: int = 1 + sampling_model: models.Model | None = None # } end of "abstract fields" _running_count: int = 0 @@ -67,7 +68,6 @@ class MCPServer(AbstractToolset[Any], ABC): _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] _exit_stack: AsyncExitStack - sampling_model: models.Model | None = None @abstractmethod @asynccontextmanager @@ -83,11 +83,6 @@ async def client_streams( raise NotImplementedError('MCP Server subclasses must implement this method.') yield - @property - def is_running(self) -> bool: - """Check if the MCP server is running.""" - return bool(self._running_count) - @property def name(self) -> str: return repr(self) @@ -373,6 +368,9 @@ async def main(): max_retries: int = 1 """The maximum number of times to retry a tool call.""" + sampling_model: models.Model | None = None + """The model to use for sampling.""" + @asynccontextmanager async def client_streams( self, @@ -471,6 +469,9 @@ class _MCPServerHTTP(MCPServer): max_retries: int = 1 """The maximum number of times to retry a tool call.""" + sampling_model: models.Model | None = None + """The model to use for sampling.""" + @property @abstractmethod def _transport_client( diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 5f732aa86..8615a8f2e 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -113,8 +113,8 @@ async def _validate_response( return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict) elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): if not self._output_schema.allows_deferred_tool_calls: - raise exceptions.UserError( - 'There are deferred tool calls but DeferredToolCalls is not among output types.' + raise exceptions.UserError( # pragma: no cover + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' ) return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): @@ -124,7 +124,7 @@ async def _validate_response( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) for validator in self._output_validators: - result_data = await validator.validate(result_data, None, self._run_ctx) + result_data = await validator.validate(result_data, self._run_ctx) return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -450,7 +450,7 @@ async def validate_structured_output( elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( - 'There are deferred tool calls but DeferredToolCalls is not among output types.' + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' ) return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): @@ -460,7 +460,7 @@ async def validate_structured_output( text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) for validator in self._output_validators: - result_data = await validator.validate(result_data, None, self._run_ctx) # pragma: no cover + result_data = await validator.validate(result_data, self._run_ctx) # pragma: no cover return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -469,7 +469,7 @@ async def validate_structured_output( async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: - text = await validator.validate(text, None, self._run_ctx) # pragma: no cover + text = await validator.validate(text, self._run_ctx) # pragma: no cover return text async def _marked_completed(self, message: _messages.ModelResponse) -> None: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_run.py b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py index 411ee35c9..a7da57999 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/_run.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py @@ -122,21 +122,22 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon if current_retry == max_retries: raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e else: - if ctx.tool_call_id: - if isinstance(e, ValidationError): - m = _messages.RetryPromptPart( - tool_name=name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=ctx.tool_call_id, - ) - e = ToolRetryError(m) - elif isinstance(e, ModelRetry): - m = _messages.RetryPromptPart( - tool_name=name, - content=e.message, - tool_call_id=ctx.tool_call_id, - ) - e = ToolRetryError(m) + if isinstance(e, ValidationError): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.errors(include_url=False, include_context=False), + ) + if ctx.tool_call_id: # pragma: no branch + m.tool_call_id = ctx.tool_call_id + e = ToolRetryError(m) + elif isinstance(e, ModelRetry): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.message, + ) + if ctx.tool_call_id: # pragma: no branch + m.tool_call_id = ctx.tool_call_id + e = ToolRetryError(m) self._retries[name] = current_retry + 1 raise e diff --git a/tests/test_agent.py b/tests/test_agent.py index 52c913aa6..fb7a188c1 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -24,6 +24,7 @@ ToolOutputSchema, ) from pydantic_ai.agent import AgentRunResult +from pydantic_ai.mcp import MCPServerStdio from pydantic_ai.messages import ( BinaryContent, ImageUrl, @@ -46,7 +47,9 @@ from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.combined import CombinedToolset from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -3439,6 +3442,14 @@ def test_deprecated_kwargs_still_work(): assert issubclass(w[0].category, DeprecationWarning) assert '`result_retries` is deprecated' in str(w[0].message) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + agent = Agent('test', mcp_servers=[MCPServerStdio('python', ['-m', 'tests.mcp_server'])]) # type: ignore[call-arg] + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert '`mcp_servers` is deprecated' in str(w[0].message) + def test_deprecated_kwargs_mixed_valid_invalid(): """Test that mix of valid deprecated and invalid kwargs raises error for invalid ones.""" @@ -3583,3 +3594,39 @@ async def only_if_plan_presented( ), ] ) + + +async def test_reentrant_context_manager(): + agent = Agent('test') + async with agent: + async with agent: + pass + + +def test_set_mcp_sampling_model(): + test_model = TestModel() + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix_')]) + agent = Agent(None, toolsets=[toolset]) + + with pytest.raises(UserError, match='No sampling model provided and no model set on the agent.'): + agent.set_mcp_sampling_model() + assert server1.sampling_model is None + assert server2.sampling_model is test_model + + agent.model = test_model + agent.set_mcp_sampling_model() + assert server1.sampling_model is test_model + assert server2.sampling_model is test_model + + function_model = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Hello')])) + with agent.override(model=function_model): + agent.set_mcp_sampling_model() + assert server1.sampling_model is function_model + assert server2.sampling_model is function_model + + function_model2 = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Goodbye')])) + agent.set_mcp_sampling_model(function_model2) + assert server1.sampling_model is function_model2 + assert server2.sampling_model is function_model2 diff --git a/tests/test_examples.py b/tests/test_examples.py index 5f30d4b87..9763e8f57 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -262,8 +262,6 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): - is_running = True - async def __aenter__(self) -> MockMCPServer: return self diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 10e7a9ec1..e717e429d 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -10,6 +10,7 @@ from pydantic_ai import Agent from pydantic_ai._utils import get_traceparent +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel from pydantic_ai.models.test import TestModel @@ -529,10 +530,11 @@ async def test_feedback(capfire: CaptureLogfire) -> None: @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') -@pytest.mark.parametrize('include_content', [True, False]) +@pytest.mark.parametrize('include_content,tool_error', [(True, False), (True, True), (False, False), (False, True)]) def test_include_tool_args_span_attributes( get_logfire_summary: Callable[[], LogfireSummary], include_content: bool, + tool_error: bool, ) -> None: """Test that tool arguments are included/excluded in span attributes based on instrumentation settings.""" @@ -543,41 +545,76 @@ def test_include_tool_args_span_attributes( @my_agent.tool_plain async def add_numbers(x: int, y: int) -> int: """Add two numbers together.""" + if tool_error: + raise ModelRetry('Tool error') return x + y - result = my_agent.run_sync('Add 42 and 42') - assert result.output == snapshot('{"add_numbers":84}') + try: + result = my_agent.run_sync('Add 42 and 42') + assert result.output == snapshot('{"add_numbers":84}') + except UnexpectedModelBehavior: + if not tool_error: + raise summary = get_logfire_summary() - [tool_attributes] = [ + tool_attributes = next( attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'add_numbers' - ] + ) if include_content: - assert tool_attributes == snapshot( - { - 'gen_ai.tool.name': 'add_numbers', - 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"x":42,"y":42}', - 'tool_response': '84', - 'logfire.msg': 'running tool: add_numbers', - 'logfire.json_schema': IsJson( - snapshot( - { - 'type': 'object', - 'properties': { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ) - ), - 'logfire.span_type': 'span', - } - ) + if tool_error: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"x":42,"y":42}', + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': """\ +Tool error + +Fix the errors and try again.\ +""", + } + ) + else: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"x":42,"y":42}', + 'tool_response': '84', + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + } + ) else: assert tool_attributes == snapshot( { diff --git a/tests/test_tools.py b/tests/test_tools.py index fe582f717..c7c41b3fd 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1232,6 +1232,51 @@ class MyModel(BaseModel): assert result.output == snapshot(MyModel(foo='a')) +def test_deferred_tool_with_tool_output_type(): + class MyModel(BaseModel): + foo: str + + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent( + TestModel(call_tools=[]), + output_type=[[ToolOutput(MyModel), ToolOutput(MyModel)], DeferredToolCalls], + toolsets=[deferred_toolset], + ) + + result = agent.run_sync('Hello') + assert result.output == snapshot(MyModel(foo='a')) + + +async def test_deferred_tool_without_output_type(): + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(), toolsets=[deferred_toolset]) + + msg = 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + + with pytest.raises(UserError, match=msg): + await agent.run('Hello') + + with pytest.raises(UserError, match=msg): + async with agent.run_stream('Hello') as result: + await result.get_output() + + def test_output_type_deferred_tool_calls_by_itself(): with pytest.raises(UserError, match='At least one output type must be provided other than `DeferredToolCalls`.'): Agent(TestModel(), output_type=DeferredToolCalls) From 778962c6470b3b21ad6fcbc9399045afc2201ce0 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 17:13:52 +0000 Subject: [PATCH 84/90] Make Agent MCP-related tests only run when mcp can be imported --- tests/test_agent.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index fb7a188c1..7277b9745 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -24,7 +24,6 @@ ToolOutputSchema, ) from pydantic_ai.agent import AgentRunResult -from pydantic_ai.mcp import MCPServerStdio from pydantic_ai.messages import ( BinaryContent, ImageUrl, @@ -3442,13 +3441,18 @@ def test_deprecated_kwargs_still_work(): assert issubclass(w[0].category, DeprecationWarning) assert '`result_retries` is deprecated' in str(w[0].message) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + try: + from pydantic_ai.mcp import MCPServerStdio - agent = Agent('test', mcp_servers=[MCPServerStdio('python', ['-m', 'tests.mcp_server'])]) # type: ignore[call-arg] - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert '`mcp_servers` is deprecated' in str(w[0].message) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + agent = Agent('test', mcp_servers=[MCPServerStdio('python', ['-m', 'tests.mcp_server'])]) # type: ignore[call-arg] + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert '`mcp_servers` is deprecated' in str(w[0].message) + except ImportError: + pass def test_deprecated_kwargs_mixed_valid_invalid(): @@ -3604,6 +3608,11 @@ async def test_reentrant_context_manager(): def test_set_mcp_sampling_model(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: + return + test_model = TestModel() server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model) From e6575a98ebfc20c3dc654382f468d3f93321bcbb Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 18:25:49 +0000 Subject: [PATCH 85/90] Add tests --- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- pydantic_ai_slim/pydantic_ai/mcp.py | 5 + .../pydantic_ai/toolsets/combined.py | 12 +- tests/test_agent.py | 21 +- tests/test_logfire.py | 2 +- tests/test_toolset.py | 560 +++++++++++++++++- 6 files changed, 588 insertions(+), 14 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index b212750ac..66789c072 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -114,7 +114,7 @@ async def validate( content=r.message, tool_name=run_context.tool_name, ) - if run_context.tool_call_id: + if run_context.tool_call_id: # pragma: no cover m.tool_call_id = run_context.tool_call_id raise ToolRetryError(m) from r else: diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 189fd9d72..ce581a6fe 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -218,6 +218,11 @@ async def __aexit__( if self._running_count <= 0: await self._exit_stack.aclose() + @property + def is_running(self) -> bool: + """Check if the MCP server is running.""" + return bool(self._running_count) + async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index ea706d0d4..6200603ce 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -24,11 +24,11 @@ class CombinedToolset(AbstractToolset[AgentDepsT]): toolsets: list[AbstractToolset[AgentDepsT]] _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]] _exit_stack: AsyncExitStack | None - _running_count: int + _entered_count: int def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): self._exit_stack = None - self._running_count = 0 + self._entered_count = 0 self.toolsets = list(toolsets) self._toolset_per_tool_name = {} @@ -44,18 +44,18 @@ def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): self._toolset_per_tool_name[name] = toolset async def __aenter__(self) -> Self: - if self._running_count == 0: + if self._entered_count == 0: self._exit_stack = AsyncExitStack() for toolset in self.toolsets: await self._exit_stack.enter_async_context(toolset) - self._running_count += 1 + self._entered_count += 1 return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None ) -> bool | None: - self._running_count -= 1 - if self._running_count <= 0 and self._exit_stack is not None: + self._entered_count -= 1 + if self._entered_count <= 0 and self._exit_stack is not None: await self._exit_stack.aclose() self._exit_stack = None return None diff --git a/tests/test_agent.py b/tests/test_agent.py index 7277b9745..95a6c7ef3 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3600,11 +3600,24 @@ async def only_if_plan_presented( ) -async def test_reentrant_context_manager(): - agent = Agent('test') +async def test_context_manager(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: + return + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + agent = Agent('test', toolsets=[toolset]) + async with agent: + assert server1.is_running + assert server2.is_running + async with agent: - pass + assert server1.is_running + assert server2.is_running def test_set_mcp_sampling_model(): @@ -3616,7 +3629,7 @@ def test_set_mcp_sampling_model(): test_model = TestModel() server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model) - toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix_')]) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) agent = Agent(None, toolsets=[toolset]) with pytest.raises(UserError, match='No sampling model provided and no model set on the agent.'): diff --git a/tests/test_logfire.py b/tests/test_logfire.py index e717e429d..8acf91eaa 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -554,7 +554,7 @@ async def add_numbers(x: int, y: int) -> int: assert result.output == snapshot('{"add_numbers":84}') except UnexpectedModelBehavior: if not tool_error: - raise + raise # pragma: no cover summary = get_logfire_summary() diff --git a/tests/test_toolset.py b/tests/test_toolset.py index 02fff0e89..a0e4bb882 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -1,15 +1,22 @@ from __future__ import annotations -from dataclasses import dataclass, replace -from typing import TypeVar +from collections.abc import Awaitable +from dataclasses import dataclass, field, replace +from typing import Any, Callable, TypeVar import pytest from inline_snapshot import snapshot from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import UserError from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.combined import CombinedToolset +from pydantic_ai.toolsets.filtered import FilteredToolset from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset +from pydantic_ai.toolsets.prepared import PreparedToolset +from pydantic_ai.toolsets.processed import ProcessedToolset from pydantic_ai.usage import Usage pytestmark = pytest.mark.anyio @@ -127,3 +134,552 @@ def subtract(a: int, b: int) -> int: bar_foo_toolset = await foo_toolset.prepare_for_run(bar_context) assert bar_foo_toolset == bar_toolset + + +async def test_prepared_toolset_user_error_add_new_tools(): + """Test that PreparedToolset raises UserError when prepare function tries to add new tools.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + async def prepare_add_new_tool(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Try to add a new tool that wasn't in the original set + new_tool = ToolDefinition( + name='new_tool', + description='A new tool', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + ) + return tool_defs + [new_tool] + + prepared_toolset = PreparedToolset(base_toolset, prepare_add_new_tool) + + with pytest.raises(UserError, match='Prepare function is not allowed to change tool names or add new tools.'): + await prepared_toolset.prepare_for_run(context) + + +async def test_prepared_toolset_user_error_change_tool_names(): + """Test that PreparedToolset raises UserError when prepare function tries to change tool names.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + async def prepare_change_names(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Try to change the name of an existing tool + modified_tool_defs: list[ToolDefinition] = [] + for tool_def in tool_defs: + if tool_def.name == 'add': + modified_tool_defs.append(replace(tool_def, name='modified_add')) + else: + modified_tool_defs.append(tool_def) + return modified_tool_defs + + prepared_toolset = PreparedToolset(base_toolset, prepare_change_names) + + with pytest.raises(UserError, match='Prepare function is not allowed to change tool names or add new tools.'): + await prepared_toolset.prepare_for_run(context) + + +async def test_prepared_toolset_allows_removing_tools(): + """Test that PreparedToolset allows removing tools from the original set.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + @base_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b + + @base_toolset.tool + def multiply(a: int, b: int) -> int: + """Multiply two numbers""" + return a * b + + async def prepare_remove_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Remove the 'subtract' tool, keep 'add' and 'multiply' + return [tool_def for tool_def in tool_defs if tool_def.name != 'subtract'] + + prepared_toolset = PreparedToolset(base_toolset, prepare_remove_tools) + + # This should not raise an error + run_toolset = await prepared_toolset.prepare_for_run(context) + + # Verify that only 'add' and 'multiply' tools are available + assert set(run_toolset.tool_names) == {'add', 'multiply'} + assert len(run_toolset.tool_defs) == 2 + + # Verify that the tools still work + assert await run_toolset.call_tool(context, 'add', {'a': 5, 'b': 3}) == 8 + assert await run_toolset.call_tool(context, 'multiply', {'a': 4, 'b': 2}) == 8 + + +async def test_prefixed_toolset_tool_defs(): + """Test that PrefixedToolset correctly prefixes tool definitions.""" + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + @base_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b + + prefixed_toolset = PrefixedToolset(base_toolset, 'math') + + # Check that tool names are prefixed + assert prefixed_toolset.tool_names == ['math_add', 'math_subtract'] + + # Check that tool definitions have prefixed names + tool_defs = prefixed_toolset.tool_defs + assert len(tool_defs) == 2 + + add_def = next(td for td in tool_defs if td.name == 'math_add') + subtract_def = next(td for td in tool_defs if td.name == 'math_subtract') + + assert add_def.name == 'math_add' + assert add_def.description == 'Add two numbers' + assert subtract_def.name == 'math_subtract' + assert subtract_def.description == 'Subtract two numbers' + + +async def test_prefixed_toolset_call_tools(): + """Test that PrefixedToolset correctly calls tools with prefixed names.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + @base_toolset.tool + def multiply(a: int, b: int) -> int: + """Multiply two numbers""" + return a * b + + prefixed_toolset = PrefixedToolset(base_toolset, 'calc') + + # Test calling tools with prefixed names + result = await prefixed_toolset.call_tool(context, 'calc_add', {'a': 5, 'b': 3}) + assert result == 8 + + result = await prefixed_toolset.call_tool(context, 'calc_multiply', {'a': 4, 'b': 2}) + assert result == 8 + + +async def test_prefixed_toolset_prepare_for_run(): + """Test that PrefixedToolset correctly prepares for run with prefixed tools.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + prefixed_toolset = PrefixedToolset(base_toolset, 'test') + + # Prepare for run + run_toolset = await prefixed_toolset.prepare_for_run(context) + + # Verify that the run toolset has prefixed tools + assert run_toolset.tool_names == ['test_add'] + assert len(run_toolset.tool_defs) == 1 + assert run_toolset.tool_defs[0].name == 'test_add' + + # Verify that the tool still works + result = await run_toolset.call_tool(context, 'test_add', {'a': 10, 'b': 5}) + assert result == 15 + + +async def test_prefixed_toolset_error_invalid_prefix(): + """Test that PrefixedToolset raises ValueError for tool names that don't start with the prefix.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + prefixed_toolset = PrefixedToolset(base_toolset, 'math') + + # Test calling with wrong prefix + with pytest.raises(ValueError, match="Tool name 'wrong_add' does not start with prefix 'math_'"): + await prefixed_toolset.call_tool(context, 'wrong_add', {'a': 1, 'b': 2}) + + # Test calling with no prefix + with pytest.raises(ValueError, match="Tool name 'add' does not start with prefix 'math_'"): + await prefixed_toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) + + # Test calling with partial prefix + with pytest.raises(ValueError, match="Tool name 'mat_add' does not start with prefix 'math_'"): + await prefixed_toolset.call_tool(context, 'mat_add', {'a': 1, 'b': 2}) + + +async def test_prefixed_toolset_empty_prefix(): + """Test that PrefixedToolset works correctly with an empty prefix.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + prefixed_toolset = PrefixedToolset(base_toolset, '') + + # Check that tool names have empty prefix (just underscore) + assert prefixed_toolset.tool_names == ['_add'] + + # Test calling the tool + result = await prefixed_toolset.call_tool(context, '_add', {'a': 3, 'b': 4}) + assert result == 7 + + # Test error for wrong name + with pytest.raises(ValueError, match="Tool name 'add' does not start with prefix '_'"): + await prefixed_toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) + + +async def test_comprehensive_toolset_composition(): # noqa: C901 + """Test that all toolsets can be composed together and work correctly.""" + + @dataclass + class TestDeps: + user_role: str = 'user' + enable_advanced: bool = True + log_calls: bool = False + log: list[str] = field(default_factory=list) + + # Create first FunctionToolset with basic math operations + math_toolset = FunctionToolset[TestDeps]() + + @math_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + @math_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b + + @math_toolset.tool + def multiply(a: int, b: int) -> int: + """Multiply two numbers""" + return a * b + + # Create second FunctionToolset with string operations + string_toolset = FunctionToolset[TestDeps]() + + @string_toolset.tool + def concat(s1: str, s2: str) -> str: + """Concatenate two strings""" + return s1 + s2 + + @string_toolset.tool + def uppercase(text: str) -> str: + """Convert text to uppercase""" + return text.upper() + + @string_toolset.tool + def reverse(text: str) -> str: + """Reverse a string""" + return text[::-1] + + # Create third FunctionToolset with advanced operations + advanced_toolset = FunctionToolset[TestDeps]() + + @advanced_toolset.tool + def power(base: int, exponent: int) -> int: + """Calculate base raised to the power of exponent""" + return base**exponent + + @advanced_toolset.tool + def factorial(n: int) -> int: + """Calculate factorial of n""" + + def _fact(x: int) -> int: + if x <= 1: + return 1 + return x * _fact(x - 1) + + return _fact(n) + + # Step 1: Prefix each FunctionToolset individually + prefixed_math = PrefixedToolset(math_toolset, 'math') + prefixed_string = PrefixedToolset(string_toolset, 'str') + prefixed_advanced = PrefixedToolset(advanced_toolset, 'adv') + + # Step 2: Combine the prefixed toolsets + combined_prefixed_toolset = CombinedToolset([prefixed_math, prefixed_string, prefixed_advanced]) + + # Step 3: Filter tools based on user role and advanced flag, now using prefixed names + def filter_tools(ctx: RunContext[TestDeps], tool_def: ToolDefinition) -> bool: + # Only allow advanced tools if enable_advanced is True + if tool_def.name.startswith('adv_') and not ctx.deps.enable_advanced: + return False + # Only allow string operations for admin users (simulating role-based access) + if tool_def.name.startswith('str_') and ctx.deps.user_role != 'admin': + return False + return True + + filtered_toolset = FilteredToolset(combined_prefixed_toolset, filter_tools) + + # Step 4: Apply prepared toolset to modify descriptions (add user role annotation) + async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Annotate each tool description with the user role + role = ctx.deps.user_role + return [replace(td, description=f'{td.description} (role: {role})') for td in tool_defs] + + prepared_toolset = PreparedToolset(filtered_toolset, prepare_add_context) + + # Step 5: Apply processed toolset to add logging (store on deps.log, optionally wrap result) + async def process_with_logging( + ctx: RunContext[TestDeps], + call_tool_func: Callable[[str, dict[str, Any], Any], Awaitable[Any]], + name: str, + tool_args: dict[str, Any], + *args: Any, + **kwargs: Any, + ) -> Any: + if ctx.deps.log_calls: + ctx.deps.log.append(f'Calling tool: {name} with args: {tool_args}') + result = await call_tool_func(name, tool_args, *args, **kwargs) + if ctx.deps.log_calls: + ctx.deps.log.append(f'Tool {name} returned: {result}') + # For demonstration, wrap the result in a dict if logging is enabled + return {'result': result} + return result + + processed_toolset = ProcessedToolset(prepared_toolset, process_with_logging) + + # Step 6: Test the fully composed toolset + # Test with regular user context (log_calls=False) + regular_deps = TestDeps(user_role='user', enable_advanced=True, log_calls=False) + regular_context = build_run_context(regular_deps) + final_toolset = await processed_toolset.prepare_for_run(regular_context) + # Tool definitions should have role annotation + assert final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_power', + description='Calculate base raised to the power of exponent (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'base': {'type': 'integer'}, 'exponent': {'type': 'integer'}}, + 'required': ['base', 'exponent'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_factorial', + description='Calculate factorial of n (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'n': {'type': 'integer'}}, + 'required': ['n'], + 'type': 'object', + }, + ), + ] + ) + # Call a tool and check result + result = await final_toolset.call_tool(regular_context, 'math_add', {'a': 5, 'b': 3}) + assert result == 8 + + # Test with admin user context (log_calls=False, should have string tools) + admin_deps = TestDeps(user_role='admin', enable_advanced=True, log_calls=False) + admin_context = build_run_context(admin_deps) + admin_final_toolset = await processed_toolset.prepare_for_run(admin_context) + assert admin_final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_concat', + description='Concatenate two strings (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'s1': {'type': 'string'}, 's2': {'type': 'string'}}, + 'required': ['s1', 's2'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_uppercase', + description='Convert text to uppercase (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'text': {'type': 'string'}}, + 'required': ['text'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_reverse', + description='Reverse a string (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'text': {'type': 'string'}}, + 'required': ['text'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_power', + description='Calculate base raised to the power of exponent (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'base': {'type': 'integer'}, 'exponent': {'type': 'integer'}}, + 'required': ['base', 'exponent'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_factorial', + description='Calculate factorial of n (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'n': {'type': 'integer'}}, + 'required': ['n'], + 'type': 'object', + }, + ), + ] + ) + result = await admin_final_toolset.call_tool(admin_context, 'str_concat', {'s1': 'Hello', 's2': 'World'}) + assert result == 'HelloWorld' + + # Test with logging enabled (log_calls=True, result should be wrapped) + logging_deps = TestDeps(user_role='admin', enable_advanced=True, log_calls=True) + logging_context = build_run_context(logging_deps) + logging_final_toolset = await processed_toolset.prepare_for_run(logging_context) + result = await logging_final_toolset.call_tool(logging_context, 'math_add', {'a': 10, 'b': 20}) + assert result == {'result': 30} + assert logging_deps.log == ["Calling tool: math_add with args: {'a': 10, 'b': 20}", 'Tool math_add returned: 30'] + + # Test with advanced features disabled (log_calls=False) + basic_deps = TestDeps(user_role='user', enable_advanced=False, log_calls=False) + basic_context = build_run_context(basic_deps) + basic_final_toolset = await processed_toolset.prepare_for_run(basic_context) + assert basic_final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ] + ) + + # Test prepare_for_run idempotency + # toolset.prepare_for_run(ctx1).prepare_for_run(ctx2) == toolset.prepare_for_run(ctx2) + ctx1 = build_run_context(TestDeps(user_role='user', enable_advanced=True, log_calls=False)) + ctx2 = build_run_context(TestDeps(user_role='admin', enable_advanced=True, log_calls=False)) + toolset_once = await processed_toolset.prepare_for_run(ctx2) + toolset_twice = await (await processed_toolset.prepare_for_run(ctx1)).prepare_for_run(ctx2) + assert toolset_once == toolset_twice From 9f9ee556c3c989d719bf5a09a1f2c6a0b792b7b3 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 20:06:58 +0000 Subject: [PATCH 86/90] AbstractToolset.call_tool now takes a ToolCallPart --- docs/mcp/client.md | 6 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 12 +- pydantic_ai_slim/pydantic_ai/_output.py | 8 +- pydantic_ai_slim/pydantic_ai/agent.py | 13 +-- pydantic_ai_slim/pydantic_ai/mcp.py | 99 +++++++++------- pydantic_ai_slim/pydantic_ai/result.py | 14 +-- .../pydantic_ai/toolsets/__init__.py | 27 +---- .../pydantic_ai/toolsets/_callable.py | 37 ++++++ .../toolsets/_individually_prepared.py | 6 +- .../pydantic_ai/toolsets/_mapped.py | 19 +--- pydantic_ai_slim/pydantic_ai/toolsets/_run.py | 77 +++++-------- .../pydantic_ai/toolsets/combined.py | 22 +--- .../pydantic_ai/toolsets/deferred.py | 10 +- .../pydantic_ai/toolsets/function.py | 14 +-- .../pydantic_ai/toolsets/prefixed.py | 13 +-- .../pydantic_ai/toolsets/processed.py | 44 ------- .../pydantic_ai/toolsets/wrapper.py | 18 +-- tests/test_examples.py | 8 +- tests/test_mcp.py | 21 ++-- tests/test_toolset.py | 107 +++++++----------- 20 files changed, 228 insertions(+), 347 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/_callable.py delete mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/processed.py diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 2c611b927..180dcfeb5 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -189,10 +189,12 @@ async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, name: str, - tool_args: dict[str, Any], + tool_args: str | dict[str, Any] None, + *args: Any, + **kwargs: Any ) -> ToolResult: """A tool call processor that passes along the deps.""" - return await call_tool(name, tool_args, metadata={'deps': ctx.deps}) + return await call_tool(name, tool_args, *args, metadata={'deps': ctx.deps}, **kwargs) server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_tool_call) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 999f1a5f5..e4fd9f38d 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -618,7 +618,7 @@ async def process_function_tools( # noqa: C901 output_parts.append(part) else: try: - result_data = await _call_tool(toolset, call, run_context) + result_data = await toolset.call_tool(call, run_context) except exceptions.UnexpectedModelBehavior as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) raise e @@ -755,7 +755,7 @@ async def _call_function_tool( with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: try: - tool_result = await _call_tool(toolset, tool_call, run_context) + tool_result = await toolset.call_tool(tool_call, run_context) except ToolRetryError as e: part = e.tool_retry if include_content and span.is_recording(): @@ -827,14 +827,6 @@ def process_content(content: Any) -> Any: return (part, extra_parts) -async def _call_tool( - toolset: AbstractToolset[DepsT], tool_call: _messages.ToolCallPart, run_context: RunContext[DepsT] -) -> Any: - run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id) - args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args) - return await toolset.call_tool(run_context, tool_call.tool_name, args_dict) - - @dataclasses.dataclass class _RunMessages: messages: list[_messages.ModelMessage] diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 66789c072..bc2102eaf 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -29,7 +29,7 @@ _OutputSpecItem, # type: ignore[reportPrivateUsage] ) from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition -from .toolsets import AbstractToolset +from .toolsets._callable import CallableToolset from .toolsets._run import RunToolset if TYPE_CHECKING: @@ -829,7 +829,7 @@ async def process( @dataclass(init=False) -class OutputToolset(AbstractToolset[AgentDepsT]): +class OutputToolset(CallableToolset[AgentDepsT]): """A toolset that contains output tools.""" _tool_defs: list[ToolDefinition] @@ -930,9 +930,7 @@ def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> Sc def _max_retries_for_tool(self, name: str) -> int: return self.max_retries - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: + async def _call_tool(self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any]) -> Any: output = await self.processors[name].call(tool_args, ctx) for validator in self.output_validators: output = await validator.validate(output, ctx, wrap_validation_errors=False) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index f2e351bd9..c10dd642c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -166,7 +166,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) - _running_count: int = dataclasses.field(repr=False) + _entered_count: int = dataclasses.field(repr=False) _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) @overload @@ -430,7 +430,7 @@ def __init__( self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) self._exit_stack = None - self._running_count = 0 + self._entered_count = 0 @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: @@ -1788,18 +1788,17 @@ def is_end_node( async def __aenter__(self) -> Self: """Enter the agent. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered with the agent so they can be used in a run.""" - if self._running_count == 0: + if self._entered_count == 0: self._exit_stack = AsyncExitStack() await self._exit_stack.enter_async_context(self._toolset) - self._running_count += 1 + self._entered_count += 1 return self async def __aexit__(self, *args: Any) -> bool | None: - self._running_count -= 1 - if self._running_count <= 0 and self._exit_stack is not None: + self._entered_count -= 1 + if self._entered_count <= 0 and self._exit_stack is not None: await self._exit_stack.aclose() self._exit_stack = None - return None def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None: """Set the sampling model on all MCP servers registered with the agent. diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index ce581a6fe..e1878e3f7 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -3,11 +3,10 @@ import base64 import functools from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from dataclasses import dataclass from pathlib import Path -from types import TracebackType from typing import Any, Callable import anyio @@ -19,10 +18,9 @@ from pydantic_ai._run_context import RunContext from pydantic_ai.tools import ToolDefinition -from .toolsets import AbstractToolset +from .toolsets._callable import CallableToolset from .toolsets._run import RunToolset from .toolsets.prefixed import PrefixedToolset -from .toolsets.processed import ProcessedToolset, ToolProcessFunc try: from mcp import types as mcp_types @@ -45,7 +43,7 @@ __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP' -class MCPServer(AbstractToolset[Any], ABC): +class MCPServer(CallableToolset[Any], ABC): """Base class for attaching agents to MCP servers. See for more information. @@ -56,7 +54,7 @@ class MCPServer(AbstractToolset[Any], ABC): log_level: mcp_types.LoggingLevel | None = None log_handler: LoggingFnT | None = None timeout: float = 5 - process_tool_call: ToolProcessFunc[Any] | None = None + process_tool_call: ProcessToolCallback | None = None allow_sampling: bool = True max_retries: int = 1 sampling_model: models.Model | None = None @@ -102,14 +100,11 @@ async def list_tools(self) -> list[mcp_types.Tool]: result = await self._client.list_tools() return result.tools - async def call_tool( + async def _call_tool( self, ctx: RunContext[Any], name: str, tool_args: dict[str, Any], - *args: Any, - metadata: dict[str, Any] | None = None, - **kwargs: Any, ) -> ToolResult: """Call a tool on the server. @@ -127,36 +122,41 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ - async with self: # Ensure server is running - try: - result = await self._client.send_request( - mcp_types.ClientRequest( - mcp_types.CallToolRequest( - method='tools/call', - params=mcp_types.CallToolRequestParams( - name=name, - arguments=tool_args, - _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, - ), - ) - ), - mcp_types.CallToolResult, - ) - except McpError as e: - raise exceptions.ModelRetry(e.error.message) - content = [self._map_tool_result_part(part) for part in result.content] + async def _call(name: str, args: dict[str, Any], metadata: dict[str, Any] | None = None) -> ToolResult: + async with self: # Ensure server is running + try: + result = await self._client.send_request( + mcp_types.ClientRequest( + mcp_types.CallToolRequest( + method='tools/call', + params=mcp_types.CallToolRequestParams( + name=name, + arguments=args, + _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + ), + ) + ), + mcp_types.CallToolResult, + ) + except McpError as e: + raise exceptions.ModelRetry(e.error.message) + + content = [self._map_tool_result_part(part) for part in result.content] + + if result.isError: + text = '\n'.join(str(part) for part in content) + raise exceptions.ModelRetry(text) + else: + return content[0] if len(content) == 1 else content - if result.isError: - text = '\n'.join(str(part) for part in content) - raise exceptions.ModelRetry(text) + if self.process_tool_call is not None: + return await self.process_tool_call(ctx, _call, name, tool_args) else: - return content[0] if len(content) == 1 else content + return await _call(name, tool_args) async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: frozen_toolset = RunToolset(self, ctx, await self.list_tool_defs()) - if self.process_tool_call: - frozen_toolset = await ProcessedToolset(frozen_toolset, self.process_tool_call).prepare_for_run(ctx) if self.tool_prefix: frozen_toolset = await PrefixedToolset(frozen_toolset, self.tool_prefix).prepare_for_run(ctx) return RunToolset(frozen_toolset, ctx, original=self) @@ -208,12 +208,7 @@ async def __aenter__(self) -> Self: self._running_count += 1 return self - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> bool | None: + async def __aexit__(self, *args: Any) -> bool | None: self._running_count -= 1 if self._running_count <= 0: await self._exit_stack.aclose() @@ -364,7 +359,7 @@ async def main(): timeout: float = 5 """The timeout in seconds to wait for the client to initialize.""" - process_tool_call: ToolProcessFunc[Any] | None = None + process_tool_call: ProcessToolCallback | None = None """Hook to customize tool calling and optionally pass extra metadata.""" allow_sampling: bool = True @@ -465,7 +460,7 @@ class _MCPServerHTTP(MCPServer): If the connection cannot be established within this time, the operation will fail. """ - process_tool_call: ToolProcessFunc[Any] | None = None + process_tool_call: ProcessToolCallback | None = None """Hook to customize tool calling and optionally pass extra metadata.""" allow_sampling: bool = True @@ -642,3 +637,23 @@ def _transport_client(self): | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] ) """The result type of an MCP tool call.""" + +CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]] +"""A function type that represents a tool call.""" + +ProcessToolCallback = Callable[ + [ + RunContext[Any], + CallToolFunc, + str, + dict[str, Any], + ], + Awaitable[ToolResult], +] +"""A process tool callback. + +It accepts a run context, the original tool call function, a tool name, and arguments. + +Allows wrapping an MCP server tool call to customize it, including adding extra request +metadata. +""" diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 8615a8f2e..49d5db6e7 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -3,7 +3,7 @@ import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from copy import copy -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from datetime import datetime from typing import Generic, cast @@ -106,11 +106,7 @@ async def _validate_response( raise exceptions.UnexpectedModelBehavior( # pragma: no cover f'Invalid response, unable to find tool call for {output_tool_name!r}' ) - run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id) - args_dict = self._toolset.validate_tool_args( - run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial - ) - return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict) + return await self._toolset.call_tool(tool_call, self._run_ctx, allow_partial=allow_partial) elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( # pragma: no cover @@ -442,11 +438,7 @@ async def validate_structured_output( raise exceptions.UnexpectedModelBehavior( # pragma: no cover f'Invalid response, unable to find tool call for {self._output_tool_name!r}' ) - run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id) - args_dict = self._toolset.validate_tool_args( - run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial - ) - return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict) + return await self._toolset.call_tool(tool_call, self._run_ctx, allow_partial=allow_partial) elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py index 65293ca17..f94a57efa 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -1,13 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal +from typing import TYPE_CHECKING, Any, Callable, Generic -from pydantic_core import SchemaValidator from typing_extensions import Self from .._run_context import AgentDepsT, RunContext +from ..messages import ToolCallPart from ..tools import ToolDefinition if TYPE_CHECKING: @@ -34,9 +33,7 @@ def tool_name_conflict_hint(self) -> str: async def __aenter__(self) -> Self: return self - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: + async def __aexit__(self, *args: Any) -> bool | None: return None @abstractmethod @@ -55,28 +52,12 @@ def tool_names(self) -> list[str]: def get_tool_def(self, name: str) -> ToolDefinition | None: return next((tool_def for tool_def in self.tool_defs if tool_def.name == name), None) - @abstractmethod - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - raise NotImplementedError() - - def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False - ) -> dict[str, Any]: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - validator = self._get_tool_args_validator(ctx, name) - if isinstance(args, str): - return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial) - else: - return validator.validate_python(args or {}, allow_partial=pyd_allow_partial) - @abstractmethod def _max_retries_for_tool(self, name: str) -> int: raise NotImplementedError() @abstractmethod - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: raise NotImplementedError() def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_callable.py b/pydantic_ai_slim/pydantic_ai/toolsets/_callable.py new file mode 100644 index 000000000..3b8562e82 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_callable.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import replace +from typing import TYPE_CHECKING, Any, Literal + +from pydantic_core import SchemaValidator + +from .._run_context import AgentDepsT, RunContext +from ..messages import ToolCallPart +from . import AbstractToolset + +if TYPE_CHECKING: + pass + + +class CallableToolset(AbstractToolset[AgentDepsT], ABC): + """A toolset that implements tool args validation and tool calling.""" + + @abstractmethod + def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator: + raise NotImplementedError() + + @abstractmethod + async def _call_tool(self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any]) -> Any: + raise NotImplementedError() + + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: + ctx = replace(ctx, tool_name=call.tool_name, tool_call_id=call.tool_call_id) + + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + validator = self._get_tool_args_validator(ctx, call.tool_name) + if isinstance(call.args, str): + args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial) + else: + args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial) + return await self._call_tool(ctx, call.tool_name, args_dict) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py index 88ccbd728..a232d709e 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_individually_prepared.py @@ -1,6 +1,6 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from .._run_context import AgentDepsT, RunContext from ..exceptions import UserError @@ -26,7 +26,9 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent name_map: dict[str, str] = {} for original_tool_def in wrapped_for_run.tool_defs: original_name = original_tool_def.name - tool_def = await self.prepare_func(ctx, original_tool_def) + + run_context = replace(ctx, tool_name=original_name, retry=ctx.retries.get(original_name, 0)) + tool_def = await self.prepare_func(run_context, original_tool_def) if not tool_def: continue diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py b/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py index 47a7a9e72..628917ae0 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_mapped.py @@ -1,14 +1,11 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import Any -from pydantic_core import SchemaValidator - from .._run_context import AgentDepsT, RunContext -from ..tools import ( - ToolDefinition, -) +from ..messages import ToolCallPart +from ..tools import ToolDefinition from . import AbstractToolset from ._run import RunToolset from .wrapper import WrapperToolset @@ -40,16 +37,12 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent def tool_defs(self) -> list[ToolDefinition]: return self._tool_defs - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return super()._get_tool_args_validator(ctx, self._map_name(name)) - def _max_retries_for_tool(self, name: str) -> int: return super()._max_retries_for_tool(self._map_name(name)) - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await super().call_tool(ctx, self._map_name(name), tool_args, *args, **kwargs) + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: + call = replace(call, tool_name=self._map_name(call.tool_name)) + return await super().call_tool(call, ctx, allow_partial=allow_partial) def _map_name(self, name: str) -> str: return self.name_map.get(name, name) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_run.py b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py index a7da57999..1bc815cf0 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/_run.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_run.py @@ -1,7 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable, Iterator -from contextlib import contextmanager +from collections.abc import Iterable from dataclasses import dataclass, replace from typing import Any @@ -12,6 +11,7 @@ from .. import messages as _messages from .._run_context import AgentDepsT, RunContext from ..exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior +from ..messages import ToolCallPart from ..tools import ToolDefinition from . import AbstractToolset from .wrapper import WrapperToolset @@ -60,45 +60,8 @@ def tool_defs(self) -> list[ToolDefinition]: def tool_names(self) -> list[str]: return self._tool_names - def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False - ) -> dict[str, Any]: - with self._with_retry(name, ctx) as ctx: - return super().validate_tool_args(ctx, name, args, allow_partial) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - with self._with_retry(name, ctx) as ctx: - try: - output = await super().call_tool(ctx, name, tool_args, *args, **kwargs) - except Exception as e: - raise e - else: - self._retries.pop(name, None) - return output - - def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None: - deferred_calls_and_defs = [ - (part, tool_def) - for part in parts - if isinstance(part, _messages.ToolCallPart) - and (tool_def := self.get_tool_def(part.tool_name)) - and tool_def.kind == 'deferred' - ] - if not deferred_calls_and_defs: - return None - - deferred_calls: list[_messages.ToolCallPart] = [] - deferred_tool_defs: dict[str, ToolDefinition] = {} - for part, tool_def in deferred_calls_and_defs: - deferred_calls.append(part) - deferred_tool_defs[part.tool_name] = tool_def - - return DeferredToolCalls(deferred_calls, deferred_tool_defs) - - @contextmanager - def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]: + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: + name = call.tool_name try: if name not in self.tool_names: if self.tool_names: @@ -107,8 +70,8 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon msg = 'No tools available.' raise ModelRetry(f'Unknown tool name: {name!r}. {msg}') - ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={}) - yield ctx + ctx = replace(ctx, retry=self._retries.get(name, 0), retries={}) + output = await super().call_tool(call, ctx, allow_partial=allow_partial) except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e: try: max_retries = self._max_retries_for_tool(name) @@ -126,18 +89,38 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon m = _messages.RetryPromptPart( tool_name=name, content=e.errors(include_url=False, include_context=False), + tool_call_id=call.tool_call_id, ) - if ctx.tool_call_id: # pragma: no branch - m.tool_call_id = ctx.tool_call_id e = ToolRetryError(m) elif isinstance(e, ModelRetry): m = _messages.RetryPromptPart( tool_name=name, content=e.message, + tool_call_id=call.tool_call_id, ) - if ctx.tool_call_id: # pragma: no branch - m.tool_call_id = ctx.tool_call_id e = ToolRetryError(m) self._retries[name] = current_retry + 1 raise e + else: + self._retries.pop(name, None) + return output + + def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None: + deferred_calls_and_defs = [ + (part, tool_def) + for part in parts + if isinstance(part, _messages.ToolCallPart) + and (tool_def := self.get_tool_def(part.tool_name)) + and tool_def.kind == 'deferred' + ] + if not deferred_calls_and_defs: + return None + + deferred_calls: list[_messages.ToolCallPart] = [] + deferred_tool_defs: dict[str, ToolDefinition] = {} + for part, tool_def in deferred_calls_and_defs: + deferred_calls.append(part) + deferred_tool_defs[part.tool_name] = tool_def + + return DeferredToolCalls(deferred_calls, deferred_tool_defs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index 6200603ce..ce49aca1d 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -4,14 +4,13 @@ from collections.abc import Sequence from contextlib import AsyncExitStack from dataclasses import dataclass -from types import TracebackType from typing import Any, Callable -from pydantic_core import SchemaValidator from typing_extensions import Self from .._run_context import AgentDepsT, RunContext from ..exceptions import UserError +from ..messages import ToolCallPart from ..tools import ToolDefinition from . import AbstractToolset from ._run import RunToolset @@ -51,14 +50,11 @@ async def __aenter__(self) -> Self: self._entered_count += 1 return self - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: + async def __aexit__(self, *args: Any) -> bool | None: self._entered_count -= 1 if self._entered_count <= 0 and self._exit_stack is not None: await self._exit_stack.aclose() self._exit_stack = None - return None async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: toolsets_for_run = await asyncio.gather(*[toolset.prepare_for_run(ctx) for toolset in self.toolsets]) @@ -73,21 +69,11 @@ def tool_defs(self) -> list[ToolDefinition]: def tool_names(self) -> list[str]: return list(self._toolset_per_tool_name.keys()) - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self._toolset_for_tool_name(name)._get_tool_args_validator(ctx, name) - - def validate_tool_args( - self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False - ) -> dict[str, Any]: - return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial) - def _max_retries_for_tool(self, name: str) -> int: return self._toolset_for_tool_name(name)._max_retries_for_tool(name) - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs) + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: + return await self._toolset_for_tool_name(call.tool_name).call_tool(call, ctx, allow_partial=allow_partial) def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: for toolset in self.toolsets: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index b6b1f8806..ac039182e 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -3,9 +3,8 @@ from dataclasses import replace from typing import Any -from pydantic_core import SchemaValidator - from .._run_context import AgentDepsT, RunContext +from ..messages import ToolCallPart from ..tools import ToolDefinition from . import AbstractToolset from ._run import RunToolset @@ -26,13 +25,8 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent def tool_defs(self) -> list[ToolDefinition]: return [replace(tool_def, kind='deferred') for tool_def in self._tool_defs] - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - raise NotImplementedError('Deferred tools cannot be validated') - def _max_retries_for_tool(self, name: str) -> int: raise NotImplementedError('Deferred tools cannot be retried') - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: raise NotImplementedError('Deferred tools cannot be called') diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index fbc60f8b0..efce3d920 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field from typing import Any, Callable, overload from pydantic.json_schema import GenerateJsonSchema @@ -18,13 +18,13 @@ ToolParams, ToolPrepareFunc, ) -from . import AbstractToolset +from ._callable import CallableToolset from ._individually_prepared import IndividuallyPreparedToolset from ._run import RunToolset @dataclass(init=False) -class FunctionToolset(AbstractToolset[AgentDepsT]): +class FunctionToolset(CallableToolset[AgentDepsT]): """A toolset that lets Python functions be used as tools.""" max_retries: int = field(default=1) @@ -187,9 +187,7 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent return RunToolset(prepared_for_run, ctx, original=self) async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None: - tool_name = tool_def.name - ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0)) - return await self.tools[tool_name].prepare_tool_def(ctx) + return await self.tools[tool_def.name].prepare_tool_def(ctx) @property def tool_defs(self) -> list[ToolDefinition]: @@ -202,7 +200,5 @@ def _max_retries_for_tool(self, name: str) -> int: tool = self.tools[name] return tool.max_retries if tool.max_retries is not None else self.max_retries - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: + async def _call_tool(self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any]) -> Any: return await self.tools[name].function_schema.call(tool_args, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py index 9210746ae..4fb78cf9a 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -3,9 +3,8 @@ from dataclasses import dataclass, replace from typing import Any -from pydantic_core import SchemaValidator - from .._run_context import AgentDepsT, RunContext +from ..messages import ToolCallPart from ..tools import ToolDefinition from ._run import RunToolset from .wrapper import WrapperToolset @@ -26,16 +25,12 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent def tool_defs(self) -> list[ToolDefinition]: return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs] - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return super()._get_tool_args_validator(ctx, self._unprefixed_tool_name(name)) - def _max_retries_for_tool(self, name: str) -> int: return super()._max_retries_for_tool(self._unprefixed_tool_name(name)) - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await super().call_tool(ctx, self._unprefixed_tool_name(name), tool_args, *args, **kwargs) + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: + call = replace(call, tool_name=self._unprefixed_tool_name(call.tool_name)) + return await super().call_tool(call, ctx, allow_partial=allow_partial) def _prefixed_tool_name(self, tool_name: str) -> str: return f'{self.prefix}_{tool_name}' diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/processed.py b/pydantic_ai_slim/pydantic_ai/toolsets/processed.py deleted file mode 100644 index c63854f7b..000000000 --- a/pydantic_ai_slim/pydantic_ai/toolsets/processed.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from collections.abc import Awaitable -from dataclasses import dataclass -from functools import partial -from typing import Any, Callable, Protocol - -from .._run_context import AgentDepsT, RunContext -from ._run import RunToolset -from .wrapper import WrapperToolset - - -class CallToolFunc(Protocol): - """A function protocol that represents a tool call.""" - - def __call__(self, name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any) -> Awaitable[Any]: ... - - -ToolProcessFunc = Callable[ - [ - RunContext[AgentDepsT], - CallToolFunc, - str, - dict[str, Any], - ], - Awaitable[Any], -] - - -@dataclass -class ProcessedToolset(WrapperToolset[AgentDepsT]): - """A toolset that lets the tool call arguments and return value be customized using a wrapper function.""" - - process: ToolProcessFunc[AgentDepsT] - - async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - wrapped_for_run = await self.wrapped.prepare_for_run(ctx) - processed = ProcessedToolset(wrapped_for_run, self.process) - return RunToolset(processed, ctx) - - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self.process(ctx, partial(self.wrapped.call_tool, ctx), name, tool_args, *args, **kwargs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 71fcca012..3ae7ed3cc 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -2,13 +2,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from types import TracebackType from typing import TYPE_CHECKING, Any, Callable -from pydantic_core import SchemaValidator from typing_extensions import Self from .._run_context import AgentDepsT, RunContext +from ..messages import ToolCallPart from ..tools import ToolDefinition from . import AbstractToolset @@ -34,10 +33,8 @@ async def __aenter__(self) -> Self: await self.wrapped.__aenter__() return self - async def __aexit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None - ) -> bool | None: - return await self.wrapped.__aexit__(exc_type, exc_value, traceback) + async def __aexit__(self, *args: Any) -> bool | None: + return await self.wrapped.__aexit__(*args) @abstractmethod async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: @@ -47,16 +44,11 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent def tool_defs(self) -> list[ToolDefinition]: return self.wrapped.tool_defs - def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator: - return self.wrapped._get_tool_args_validator(ctx, name) - def _max_retries_for_tool(self, name: str) -> int: return self.wrapped._max_retries_for_tool(name) - async def call_tool( - self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: - return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs) + async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: + return await self.wrapped.call_tool(call, ctx, allow_partial=allow_partial) def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: return self.wrapped.accept(visitor) diff --git a/tests/test_examples.py b/tests/test_examples.py index 9763e8f57..5fa7a8adf 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -16,7 +16,6 @@ import pytest from _pytest.mark import ParameterSet from devtools import debug -from pydantic_core import SchemaValidator, core_schema from pytest_examples import CodeExample, EvalExample, find_examples from pytest_mock import MockerFixture from rich.console import Console @@ -275,15 +274,10 @@ async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]: def tool_defs(self) -> list[ToolDefinition]: return [] - def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator: - return SchemaValidator(core_schema.any_schema()) # pragma: lax no cover - def _max_retries_for_tool(self, name: str) -> int: return 0 # pragma: lax no cover - async def call_tool( - self, ctx: RunContext[Any], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any - ) -> Any: + async def call_tool(self, call: ToolCallPart, ctx: RunContext[Any], allow_partial: bool = False) -> Any: return None # pragma: lax no cover diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 2394fdd13..1037e5a21 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -28,7 +28,6 @@ from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext -from pydantic_ai.toolsets.processed import CallToolFunc from pydantic_ai.usage import Usage from .conftest import IsDatetime, IsNow, IsStr, try_import @@ -38,7 +37,7 @@ from mcp.types import CreateMessageRequestParams, ImageContent, TextContent from pydantic_ai._mcp import map_from_mcp_params, map_from_model_response - from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio, ToolResult + from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult from pydantic_ai.models.google import GoogleModel from pydantic_ai.models.openai import OpenAIModel from pydantic_ai.providers.google import GoogleProvider @@ -80,7 +79,9 @@ async def test_stdio_server(run_context: RunContext[int]): assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') # Test calling the temperature conversion tool - result = await server.call_tool(run_context, 'celsius_to_fahrenheit', {'celsius': 0}) + result = await server.call_tool( + ToolCallPart(tool_name='celsius_to_fahrenheit', args={'celsius': 0}), run_context + ) assert result == snapshot('32.0') @@ -113,12 +114,12 @@ async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, name: str, - args: dict[str, Any], + tool_args: dict[str, Any], ) -> ToolResult: """A process_tool_call that sets a flag and sends deps as metadata.""" nonlocal called called = True - return await call_tool(name, args, metadata={'deps': ctx.deps}) + return await call_tool(name, tool_args, {'deps': ctx.deps}) server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) async with server: @@ -271,7 +272,7 @@ async def test_log_level_unset(run_context: RunContext[int]): assert len(tools) == snapshot(13) assert tools[10].name == 'get_log_level' - result = await server.call_tool(run_context, 'get_log_level', {}) + result = await server.call_tool(ToolCallPart(tool_name='get_log_level', args={}), run_context) assert result == snapshot('unset') @@ -279,7 +280,7 @@ async def test_log_level_set(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], log_level='info') assert server.log_level == 'info' async with server: - result = await server.call_tool(run_context, 'get_log_level', {}) + result = await server.call_tool(ToolCallPart(tool_name='get_log_level', args={}), run_context) assert result == snapshot('info') @@ -992,7 +993,7 @@ async def test_client_sampling(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: - result = await server.call_tool(run_context, 'use_sampling', {'foo': 'bar'}) + result = await server.call_tool(ToolCallPart(tool_name='use_sampling', args={'foo': 'bar'}), run_context) assert result == snapshot( { 'meta': None, @@ -1009,7 +1010,7 @@ async def test_client_sampling_disabled(run_context: RunContext[int]): server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: with pytest.raises(ModelRetry, match='Error executing tool use_sampling: Sampling not supported'): - await server.call_tool(run_context, 'use_sampling', {'foo': 'bar'}) + await server.call_tool(ToolCallPart(tool_name='use_sampling', args={'foo': 'bar'}), run_context) async def test_mcp_server_raises_mcp_error( @@ -1024,7 +1025,7 @@ async def test_mcp_server_raises_mcp_error( new=AsyncMock(side_effect=mcp_error), ): with pytest.raises(ModelRetry, match='Test MCP error conversion'): - await mcp_server.call_tool(run_context, 'test_tool', {}) + await mcp_server.call_tool(ToolCallPart(tool_name='test_tool', args={}), run_context) def test_map_from_mcp_params_model_request(): diff --git a/tests/test_toolset.py b/tests/test_toolset.py index a0e4bb882..15dc01a7d 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolset.py @@ -1,14 +1,14 @@ from __future__ import annotations -from collections.abc import Awaitable -from dataclasses import dataclass, field, replace -from typing import Any, Callable, TypeVar +from dataclasses import dataclass, replace +from typing import TypeVar import pytest from inline_snapshot import snapshot from pydantic_ai._run_context import RunContext from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import ToolCallPart from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.combined import CombinedToolset @@ -16,7 +16,6 @@ from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.toolsets.prefixed import PrefixedToolset from pydantic_ai.toolsets.prepared import PreparedToolset -from pydantic_ai.toolsets.processed import ProcessedToolset from pydantic_ai.usage import Usage pytestmark = pytest.mark.anyio @@ -69,13 +68,15 @@ def add(a: int, b: int) -> int: ) ] ) - assert await toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) == 3 + assert await toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2}), context) == 3 no_prefix_context = build_run_context(PrefixDeps()) no_prefix_toolset = await toolset.prepare_for_run(no_prefix_context) assert no_prefix_toolset.tool_names == toolset.tool_names assert no_prefix_toolset.tool_defs == toolset.tool_defs - assert await no_prefix_toolset.call_tool(no_prefix_context, 'add', {'a': 1, 'b': 2}) == 3 + assert ( + await no_prefix_toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2}), no_prefix_context) == 3 + ) foo_context = build_run_context(PrefixDeps(prefix='foo')) foo_toolset = await toolset.prepare_for_run(foo_context) @@ -94,7 +95,7 @@ def add(a: int, b: int) -> int: ) ] ) - assert await foo_toolset.call_tool(foo_context, 'foo_add', {'a': 1, 'b': 2}) == 3 + assert await foo_toolset.call_tool(ToolCallPart(tool_name='foo_add', args={'a': 1, 'b': 2}), foo_context) == 3 @toolset.tool def subtract(a: int, b: int) -> int: @@ -130,7 +131,7 @@ def subtract(a: int, b: int) -> int: ), ] ) - assert await bar_toolset.call_tool(bar_context, 'bar_add', {'a': 1, 'b': 2}) == 3 + assert await bar_toolset.call_tool(ToolCallPart(tool_name='bar_add', args={'a': 1, 'b': 2}), bar_context) == 3 bar_foo_toolset = await foo_toolset.prepare_for_run(bar_context) assert bar_foo_toolset == bar_toolset @@ -226,8 +227,8 @@ async def prepare_remove_tools(ctx: RunContext[None], tool_defs: list[ToolDefini assert len(run_toolset.tool_defs) == 2 # Verify that the tools still work - assert await run_toolset.call_tool(context, 'add', {'a': 5, 'b': 3}) == 8 - assert await run_toolset.call_tool(context, 'multiply', {'a': 4, 'b': 2}) == 8 + assert await run_toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 5, 'b': 3}), context) == 8 + assert await run_toolset.call_tool(ToolCallPart(tool_name='multiply', args={'a': 4, 'b': 2}), context) == 8 async def test_prefixed_toolset_tool_defs(): @@ -262,7 +263,7 @@ def subtract(a: int, b: int) -> int: assert subtract_def.description == 'Subtract two numbers' -async def test_prefixed_toolset_call_tools(): +async def test_prefixed_toolsetcall_tools(): """Test that PrefixedToolset correctly calls tools with prefixed names.""" context = build_run_context(None) base_toolset = FunctionToolset[None]() @@ -280,10 +281,10 @@ def multiply(a: int, b: int) -> int: prefixed_toolset = PrefixedToolset(base_toolset, 'calc') # Test calling tools with prefixed names - result = await prefixed_toolset.call_tool(context, 'calc_add', {'a': 5, 'b': 3}) + result = await prefixed_toolset.call_tool(ToolCallPart(tool_name='calc_add', args={'a': 5, 'b': 3}), context) assert result == 8 - result = await prefixed_toolset.call_tool(context, 'calc_multiply', {'a': 4, 'b': 2}) + result = await prefixed_toolset.call_tool(ToolCallPart(tool_name='calc_multiply', args={'a': 4, 'b': 2}), context) assert result == 8 @@ -308,7 +309,7 @@ def add(a: int, b: int) -> int: assert run_toolset.tool_defs[0].name == 'test_add' # Verify that the tool still works - result = await run_toolset.call_tool(context, 'test_add', {'a': 10, 'b': 5}) + result = await run_toolset.call_tool(ToolCallPart(tool_name='test_add', args={'a': 10, 'b': 5}), context) assert result == 15 @@ -326,15 +327,15 @@ def add(a: int, b: int) -> int: # Test calling with wrong prefix with pytest.raises(ValueError, match="Tool name 'wrong_add' does not start with prefix 'math_'"): - await prefixed_toolset.call_tool(context, 'wrong_add', {'a': 1, 'b': 2}) + await prefixed_toolset.call_tool(ToolCallPart(tool_name='wrong_add', args={'a': 1, 'b': 2}), context) # Test calling with no prefix with pytest.raises(ValueError, match="Tool name 'add' does not start with prefix 'math_'"): - await prefixed_toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) + await prefixed_toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2}), context) # Test calling with partial prefix with pytest.raises(ValueError, match="Tool name 'mat_add' does not start with prefix 'math_'"): - await prefixed_toolset.call_tool(context, 'mat_add', {'a': 1, 'b': 2}) + await prefixed_toolset.call_tool(ToolCallPart(tool_name='mat_add', args={'a': 1, 'b': 2}), context) async def test_prefixed_toolset_empty_prefix(): @@ -353,23 +354,21 @@ def add(a: int, b: int) -> int: assert prefixed_toolset.tool_names == ['_add'] # Test calling the tool - result = await prefixed_toolset.call_tool(context, '_add', {'a': 3, 'b': 4}) + result = await prefixed_toolset.call_tool(ToolCallPart(tool_name='_add', args={'a': 3, 'b': 4}), context) assert result == 7 # Test error for wrong name with pytest.raises(ValueError, match="Tool name 'add' does not start with prefix '_'"): - await prefixed_toolset.call_tool(context, 'add', {'a': 1, 'b': 2}) + await prefixed_toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2}), context) -async def test_comprehensive_toolset_composition(): # noqa: C901 +async def test_comprehensive_toolset_composition(): """Test that all toolsets can be composed together and work correctly.""" @dataclass class TestDeps: user_role: str = 'user' enable_advanced: bool = True - log_calls: bool = False - log: list[str] = field(default_factory=list) # Create first FunctionToolset with basic math operations math_toolset = FunctionToolset[TestDeps]() @@ -444,7 +443,7 @@ def filter_tools(ctx: RunContext[TestDeps], tool_def: ToolDefinition) -> bool: return False return True - filtered_toolset = FilteredToolset(combined_prefixed_toolset, filter_tools) + filtered_toolset = FilteredToolset[TestDeps](combined_prefixed_toolset, filter_tools) # Step 4: Apply prepared toolset to modify descriptions (add user role annotation) async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: @@ -454,31 +453,11 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef prepared_toolset = PreparedToolset(filtered_toolset, prepare_add_context) - # Step 5: Apply processed toolset to add logging (store on deps.log, optionally wrap result) - async def process_with_logging( - ctx: RunContext[TestDeps], - call_tool_func: Callable[[str, dict[str, Any], Any], Awaitable[Any]], - name: str, - tool_args: dict[str, Any], - *args: Any, - **kwargs: Any, - ) -> Any: - if ctx.deps.log_calls: - ctx.deps.log.append(f'Calling tool: {name} with args: {tool_args}') - result = await call_tool_func(name, tool_args, *args, **kwargs) - if ctx.deps.log_calls: - ctx.deps.log.append(f'Tool {name} returned: {result}') - # For demonstration, wrap the result in a dict if logging is enabled - return {'result': result} - return result - - processed_toolset = ProcessedToolset(prepared_toolset, process_with_logging) - - # Step 6: Test the fully composed toolset - # Test with regular user context (log_calls=False) - regular_deps = TestDeps(user_role='user', enable_advanced=True, log_calls=False) + # Step 5: Test the fully composed toolset + # Test with regular user context + regular_deps = TestDeps(user_role='user', enable_advanced=True) regular_context = build_run_context(regular_deps) - final_toolset = await processed_toolset.prepare_for_run(regular_context) + final_toolset = await prepared_toolset.prepare_for_run(regular_context) # Tool definitions should have role annotation assert final_toolset.tool_defs == snapshot( [ @@ -535,13 +514,13 @@ async def process_with_logging( ] ) # Call a tool and check result - result = await final_toolset.call_tool(regular_context, 'math_add', {'a': 5, 'b': 3}) + result = await final_toolset.call_tool(ToolCallPart(tool_name='math_add', args={'a': 5, 'b': 3}), regular_context) assert result == 8 - # Test with admin user context (log_calls=False, should have string tools) - admin_deps = TestDeps(user_role='admin', enable_advanced=True, log_calls=False) + # Test with admin user context (should have string tools) + admin_deps = TestDeps(user_role='admin', enable_advanced=True) admin_context = build_run_context(admin_deps) - admin_final_toolset = await processed_toolset.prepare_for_run(admin_context) + admin_final_toolset = await prepared_toolset.prepare_for_run(admin_context) assert admin_final_toolset.tool_defs == snapshot( [ ToolDefinition( @@ -626,21 +605,15 @@ async def process_with_logging( ), ] ) - result = await admin_final_toolset.call_tool(admin_context, 'str_concat', {'s1': 'Hello', 's2': 'World'}) + result = await admin_final_toolset.call_tool( + ToolCallPart(tool_name='str_concat', args={'s1': 'Hello', 's2': 'World'}), admin_context + ) assert result == 'HelloWorld' - # Test with logging enabled (log_calls=True, result should be wrapped) - logging_deps = TestDeps(user_role='admin', enable_advanced=True, log_calls=True) - logging_context = build_run_context(logging_deps) - logging_final_toolset = await processed_toolset.prepare_for_run(logging_context) - result = await logging_final_toolset.call_tool(logging_context, 'math_add', {'a': 10, 'b': 20}) - assert result == {'result': 30} - assert logging_deps.log == ["Calling tool: math_add with args: {'a': 10, 'b': 20}", 'Tool math_add returned: 30'] - - # Test with advanced features disabled (log_calls=False) - basic_deps = TestDeps(user_role='user', enable_advanced=False, log_calls=False) + # Test with advanced features disabled + basic_deps = TestDeps(user_role='user', enable_advanced=False) basic_context = build_run_context(basic_deps) - basic_final_toolset = await processed_toolset.prepare_for_run(basic_context) + basic_final_toolset = await prepared_toolset.prepare_for_run(basic_context) assert basic_final_toolset.tool_defs == snapshot( [ ToolDefinition( @@ -678,8 +651,8 @@ async def process_with_logging( # Test prepare_for_run idempotency # toolset.prepare_for_run(ctx1).prepare_for_run(ctx2) == toolset.prepare_for_run(ctx2) - ctx1 = build_run_context(TestDeps(user_role='user', enable_advanced=True, log_calls=False)) - ctx2 = build_run_context(TestDeps(user_role='admin', enable_advanced=True, log_calls=False)) - toolset_once = await processed_toolset.prepare_for_run(ctx2) - toolset_twice = await (await processed_toolset.prepare_for_run(ctx1)).prepare_for_run(ctx2) + ctx1 = build_run_context(TestDeps(user_role='user', enable_advanced=True)) + ctx2 = build_run_context(TestDeps(user_role='admin', enable_advanced=True)) + toolset_once = await prepared_toolset.prepare_for_run(ctx2) + toolset_twice = await (await prepared_toolset.prepare_for_run(ctx1)).prepare_for_run(ctx2) assert toolset_once == toolset_twice From a3c9a591c6f389ec1edec5a31a571248ba389e88 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 20:21:38 +0000 Subject: [PATCH 87/90] Fix MCP process_tool_call example --- docs/mcp/client.md | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 180dcfeb5..a12da956d 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -179,22 +179,19 @@ call needs. from typing import Any from pydantic_ai import Agent -from pydantic_ai.mcp import MCPServerStdio, ToolResult +from pydantic_ai.mcp import CallToolFunc, MCPServerStdio, ToolResult from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext -from pydantic_ai.toolsets.processed import CallToolFunc async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, name: str, - tool_args: str | dict[str, Any] None, - *args: Any, - **kwargs: Any + tool_args: dict[str, Any], ) -> ToolResult: """A tool call processor that passes along the deps.""" - return await call_tool(name, tool_args, *args, metadata={'deps': ctx.deps}, **kwargs) + return await call_tool(name, tool_args, {'deps': ctx.deps}) server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_tool_call) From 6eae653a5fc8e6ac97c3eefde97a2b8ab976e22d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 21:03:53 +0000 Subject: [PATCH 88/90] Fix test coverage --- .../pydantic_ai/toolsets/prefixed.py | 2 +- tests/{test_toolset.py => test_toolsets.py} | 218 +++--------------- 2 files changed, 28 insertions(+), 192 deletions(-) rename tests/{test_toolset.py => test_toolsets.py} (73%) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py index 4fb78cf9a..7baef9ed2 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -26,7 +26,7 @@ def tool_defs(self) -> list[ToolDefinition]: return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs] def _max_retries_for_tool(self, name: str) -> int: - return super()._max_retries_for_tool(self._unprefixed_tool_name(name)) + return super()._max_retries_for_tool(self._unprefixed_tool_name(name)) # pragma: no cover async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: call = replace(call, tool_name=self._unprefixed_tool_name(call.tool_name)) diff --git a/tests/test_toolset.py b/tests/test_toolsets.py similarity index 73% rename from tests/test_toolset.py rename to tests/test_toolsets.py index 15dc01a7d..0c18ca36a 100644 --- a/tests/test_toolset.py +++ b/tests/test_toolsets.py @@ -145,7 +145,7 @@ async def test_prepared_toolset_user_error_add_new_tools(): @base_toolset.tool def add(a: int, b: int) -> int: """Add two numbers""" - return a + b + return a + b # pragma: no cover async def prepare_add_new_tool(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: # Try to add a new tool that wasn't in the original set @@ -175,7 +175,7 @@ async def test_prepared_toolset_user_error_change_tool_names(): @base_toolset.tool def add(a: int, b: int) -> int: """Add two numbers""" - return a + b + return a + b # pragma: no cover async def prepare_change_names(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: # Try to change the name of an existing tool @@ -193,126 +193,6 @@ async def prepare_change_names(ctx: RunContext[None], tool_defs: list[ToolDefini await prepared_toolset.prepare_for_run(context) -async def test_prepared_toolset_allows_removing_tools(): - """Test that PreparedToolset allows removing tools from the original set.""" - context = build_run_context(None) - base_toolset = FunctionToolset[None]() - - @base_toolset.tool - def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - @base_toolset.tool - def subtract(a: int, b: int) -> int: - """Subtract two numbers""" - return a - b - - @base_toolset.tool - def multiply(a: int, b: int) -> int: - """Multiply two numbers""" - return a * b - - async def prepare_remove_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: - # Remove the 'subtract' tool, keep 'add' and 'multiply' - return [tool_def for tool_def in tool_defs if tool_def.name != 'subtract'] - - prepared_toolset = PreparedToolset(base_toolset, prepare_remove_tools) - - # This should not raise an error - run_toolset = await prepared_toolset.prepare_for_run(context) - - # Verify that only 'add' and 'multiply' tools are available - assert set(run_toolset.tool_names) == {'add', 'multiply'} - assert len(run_toolset.tool_defs) == 2 - - # Verify that the tools still work - assert await run_toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 5, 'b': 3}), context) == 8 - assert await run_toolset.call_tool(ToolCallPart(tool_name='multiply', args={'a': 4, 'b': 2}), context) == 8 - - -async def test_prefixed_toolset_tool_defs(): - """Test that PrefixedToolset correctly prefixes tool definitions.""" - base_toolset = FunctionToolset[None]() - - @base_toolset.tool - def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - @base_toolset.tool - def subtract(a: int, b: int) -> int: - """Subtract two numbers""" - return a - b - - prefixed_toolset = PrefixedToolset(base_toolset, 'math') - - # Check that tool names are prefixed - assert prefixed_toolset.tool_names == ['math_add', 'math_subtract'] - - # Check that tool definitions have prefixed names - tool_defs = prefixed_toolset.tool_defs - assert len(tool_defs) == 2 - - add_def = next(td for td in tool_defs if td.name == 'math_add') - subtract_def = next(td for td in tool_defs if td.name == 'math_subtract') - - assert add_def.name == 'math_add' - assert add_def.description == 'Add two numbers' - assert subtract_def.name == 'math_subtract' - assert subtract_def.description == 'Subtract two numbers' - - -async def test_prefixed_toolsetcall_tools(): - """Test that PrefixedToolset correctly calls tools with prefixed names.""" - context = build_run_context(None) - base_toolset = FunctionToolset[None]() - - @base_toolset.tool - def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - @base_toolset.tool - def multiply(a: int, b: int) -> int: - """Multiply two numbers""" - return a * b - - prefixed_toolset = PrefixedToolset(base_toolset, 'calc') - - # Test calling tools with prefixed names - result = await prefixed_toolset.call_tool(ToolCallPart(tool_name='calc_add', args={'a': 5, 'b': 3}), context) - assert result == 8 - - result = await prefixed_toolset.call_tool(ToolCallPart(tool_name='calc_multiply', args={'a': 4, 'b': 2}), context) - assert result == 8 - - -async def test_prefixed_toolset_prepare_for_run(): - """Test that PrefixedToolset correctly prepares for run with prefixed tools.""" - context = build_run_context(None) - base_toolset = FunctionToolset[None]() - - @base_toolset.tool - def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - prefixed_toolset = PrefixedToolset(base_toolset, 'test') - - # Prepare for run - run_toolset = await prefixed_toolset.prepare_for_run(context) - - # Verify that the run toolset has prefixed tools - assert run_toolset.tool_names == ['test_add'] - assert len(run_toolset.tool_defs) == 1 - assert run_toolset.tool_defs[0].name == 'test_add' - - # Verify that the tool still works - result = await run_toolset.call_tool(ToolCallPart(tool_name='test_add', args={'a': 10, 'b': 5}), context) - assert result == 15 - - async def test_prefixed_toolset_error_invalid_prefix(): """Test that PrefixedToolset raises ValueError for tool names that don't start with the prefix.""" context = build_run_context(None) @@ -321,7 +201,7 @@ async def test_prefixed_toolset_error_invalid_prefix(): @base_toolset.tool def add(a: int, b: int) -> int: """Add two numbers""" - return a + b + return a + b # pragma: no cover prefixed_toolset = PrefixedToolset(base_toolset, 'math') @@ -329,38 +209,6 @@ def add(a: int, b: int) -> int: with pytest.raises(ValueError, match="Tool name 'wrong_add' does not start with prefix 'math_'"): await prefixed_toolset.call_tool(ToolCallPart(tool_name='wrong_add', args={'a': 1, 'b': 2}), context) - # Test calling with no prefix - with pytest.raises(ValueError, match="Tool name 'add' does not start with prefix 'math_'"): - await prefixed_toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2}), context) - - # Test calling with partial prefix - with pytest.raises(ValueError, match="Tool name 'mat_add' does not start with prefix 'math_'"): - await prefixed_toolset.call_tool(ToolCallPart(tool_name='mat_add', args={'a': 1, 'b': 2}), context) - - -async def test_prefixed_toolset_empty_prefix(): - """Test that PrefixedToolset works correctly with an empty prefix.""" - context = build_run_context(None) - base_toolset = FunctionToolset[None]() - - @base_toolset.tool - def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - - prefixed_toolset = PrefixedToolset(base_toolset, '') - - # Check that tool names have empty prefix (just underscore) - assert prefixed_toolset.tool_names == ['_add'] - - # Test calling the tool - result = await prefixed_toolset.call_tool(ToolCallPart(tool_name='_add', args={'a': 3, 'b': 4}), context) - assert result == 7 - - # Test error for wrong name - with pytest.raises(ValueError, match="Tool name 'add' does not start with prefix '_'"): - await prefixed_toolset.call_tool(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2}), context) - async def test_comprehensive_toolset_composition(): """Test that all toolsets can be composed together and work correctly.""" @@ -381,12 +229,12 @@ def add(a: int, b: int) -> int: @math_toolset.tool def subtract(a: int, b: int) -> int: """Subtract two numbers""" - return a - b + return a - b # pragma: no cover @math_toolset.tool def multiply(a: int, b: int) -> int: """Multiply two numbers""" - return a * b + return a * b # pragma: no cover # Create second FunctionToolset with string operations string_toolset = FunctionToolset[TestDeps]() @@ -399,12 +247,12 @@ def concat(s1: str, s2: str) -> str: @string_toolset.tool def uppercase(text: str) -> str: """Convert text to uppercase""" - return text.upper() + return text.upper() # pragma: no cover @string_toolset.tool def reverse(text: str) -> str: """Reverse a string""" - return text[::-1] + return text[::-1] # pragma: no cover # Create third FunctionToolset with advanced operations advanced_toolset = FunctionToolset[TestDeps]() @@ -412,18 +260,7 @@ def reverse(text: str) -> str: @advanced_toolset.tool def power(base: int, exponent: int) -> int: """Calculate base raised to the power of exponent""" - return base**exponent - - @advanced_toolset.tool - def factorial(n: int) -> int: - """Calculate factorial of n""" - - def _fact(x: int) -> int: - if x <= 1: - return 1 - return x * _fact(x - 1) - - return _fact(n) + return base**exponent # pragma: no cover # Step 1: Prefix each FunctionToolset individually prefixed_math = PrefixedToolset(math_toolset, 'math') @@ -501,16 +338,6 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef 'type': 'object', }, ), - ToolDefinition( - name='adv_factorial', - description='Calculate factorial of n (role: user)', - parameters_json_schema={ - 'additionalProperties': False, - 'properties': {'n': {'type': 'integer'}}, - 'required': ['n'], - 'type': 'object', - }, - ), ] ) # Call a tool and check result @@ -593,16 +420,6 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef 'type': 'object', }, ), - ToolDefinition( - name='adv_factorial', - description='Calculate factorial of n (role: admin)', - parameters_json_schema={ - 'additionalProperties': False, - 'properties': {'n': {'type': 'integer'}}, - 'required': ['n'], - 'type': 'object', - }, - ), ] ) result = await admin_final_toolset.call_tool( @@ -656,3 +473,22 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef toolset_once = await prepared_toolset.prepare_for_run(ctx2) toolset_twice = await (await prepared_toolset.prepare_for_run(ctx1)).prepare_for_run(ctx2) assert toolset_once == toolset_twice + + +async def test_context_manager(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: + return + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + + async with toolset: + assert server1.is_running + assert server2.is_running + + async with toolset: + assert server1.is_running + assert server2.is_running From b2aa894350da337cec80e4f9635c074395f39067 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 4 Jul 2025 21:42:46 +0000 Subject: [PATCH 89/90] Improve coverage --- tests/test_toolsets.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index 0c18ca36a..9c2b0817a 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -177,6 +177,11 @@ def add(a: int, b: int) -> int: """Add two numbers""" return a + b # pragma: no cover + @base_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: no cover + async def prepare_change_names(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: # Try to change the name of an existing tool modified_tool_defs: list[ToolDefinition] = [] From 1c2d221fa2ea544f394726d8ed0b2b59de71481c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 8 Jul 2025 15:48:40 +0000 Subject: [PATCH 90/90] Address feedback --- docs/mcp/client.md | 4 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 18 +++--- pydantic_ai_slim/pydantic_ai/_output.py | 12 ++-- pydantic_ai_slim/pydantic_ai/agent.py | 24 ++++---- pydantic_ai_slim/pydantic_ai/mcp.py | 58 +++++++++++-------- pydantic_ai_slim/pydantic_ai/tools.py | 1 + .../pydantic_ai/toolsets/__init__.py | 2 +- .../pydantic_ai/toolsets/_callable.py | 4 +- .../pydantic_ai/toolsets/combined.py | 34 ++++++----- .../pydantic_ai/toolsets/deferred.py | 5 +- .../pydantic_ai/toolsets/wrapper.py | 4 +- tests/test_agent.py | 4 +- tests/test_toolsets.py | 3 +- 13 files changed, 101 insertions(+), 72 deletions(-) diff --git a/docs/mcp/client.md b/docs/mcp/client.md index a12da956d..cea2023f7 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -137,7 +137,9 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. !!! note - When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager is responsible for starting and stopping the server. + When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers as `toolsets` on an [`Agent`][pydantic_ai.Agent], you can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager to start and stop the server around the context where it'll be used. You can also use [`async with server`][pydantic_ai.mcp.MCPServerStdio.__aenter__] to manage the starting and stopping of a specific server, for example if you'd like to use it with multiple agents. + + If you don't explicitly start the server using one of these context managers, it will automatically be started when it's needed (e.g. to list the available tools or call a specific tool), but it's more efficient to do so around the entire context where you expect it to be used. ```python {title="mcp_stdio_client.py" py="3.10"} from pydantic_ai import Agent diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e4fd9f38d..3c9c8f0dc 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -639,14 +639,16 @@ async def process_function_tools( # noqa: C901 # Then, we handle function tool calls calls_to_run: list[_messages.ToolCallPart] = [] if final_result and ctx.deps.end_strategy == 'early': - for call in tool_calls_by_kind['function']: - output_parts.append( + output_parts.extend( + [ _messages.ToolReturnPart( tool_name=call.tool_name, content='Tool not executed - a final result was already processed.', tool_call_id=call.tool_call_id, ) - ) + for call in tool_calls_by_kind['function'] + ] + ) else: calls_to_run.extend(tool_calls_by_kind['function']) @@ -776,8 +778,8 @@ async def _call_function_tool( def process_content(content: Any) -> Any: if isinstance(content, _messages.ToolReturn): raise exceptions.UserError( - f"{tool_call.tool_name}'s return contains invalid nested ToolReturn objects. " - f'ToolReturn should be used directly.' + f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. ' + f'`ToolReturn` should be used directly.' ) elif isinstance(content, _messages.MultiModalContentTypes): if isinstance(content, _messages.BinaryContent): @@ -792,8 +794,8 @@ def process_content(content: Any) -> Any: ) ) return f'See file {identifier}' - else: - return content + + return content if isinstance(tool_result, _messages.ToolReturn): if ( @@ -805,7 +807,7 @@ def process_content(content: Any) -> Any: ) ): raise exceptions.UserError( - f"{tool_call.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. " + f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. ' f'Please use `content` instead.' ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index bc2102eaf..4f476eb66 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -294,7 +294,7 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDa @dataclass(init=False) class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] - _toolset: OutputToolset[Any] | None = None + _toolset: OutputToolset[Any] | None def __init__( self, @@ -477,7 +477,7 @@ async def process( @dataclass(init=False) class ToolOutputSchema(OutputSchema[OutputDataT]): - _toolset: OutputToolset[Any] | None = None + _toolset: OutputToolset[Any] | None def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool): super().__init__(allows_deferred_tool_calls) @@ -834,8 +834,8 @@ class OutputToolset(CallableToolset[AgentDepsT]): _tool_defs: list[ToolDefinition] processors: dict[str, ObjectOutputProcessor[Any]] - max_retries: int = field(default=1) - output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list) + max_retries: int + output_validators: list[OutputValidator[AgentDepsT, Any]] @classmethod def build( @@ -910,12 +910,12 @@ def __init__( tool_defs: list[ToolDefinition], processors: dict[str, ObjectOutputProcessor[Any]], max_retries: int = 1, - output_validators: list[OutputValidator[AgentDepsT, Any]] = [], + output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None, ): self.processors = processors self._tool_defs = tool_defs self.max_retries = max_retries - self.output_validators = output_validators + self.output_validators = output_validators or [] async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: return RunToolset(self, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 7c2b5531f..c46699fd0 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -4,6 +4,7 @@ import inspect import json import warnings +from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar @@ -166,6 +167,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _enter_lock: Lock = dataclasses.field(repr=False) _entered_count: int = dataclasses.field(repr=False) _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) @@ -433,8 +435,9 @@ def __init__( self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) - self._exit_stack = None + self._enter_lock = Lock() self._entered_count = 0 + self._exit_stack = None @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: @@ -1795,18 +1798,19 @@ def is_end_node( return isinstance(node, End) async def __aenter__(self) -> Self: - """Enter the agent. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered with the agent so they can be used in a run.""" - if self._entered_count == 0: - self._exit_stack = AsyncExitStack() - await self._exit_stack.enter_async_context(self._toolset) - self._entered_count += 1 + async with self._enter_lock: + if self._entered_count == 0: + self._exit_stack = AsyncExitStack() + await self._exit_stack.enter_async_context(self._toolset) + self._entered_count += 1 return self async def __aexit__(self, *args: Any) -> bool | None: - self._entered_count -= 1 - if self._entered_count <= 0 and self._exit_stack is not None: - await self._exit_stack.aclose() - self._exit_stack = None + async with self._enter_lock: + self._entered_count -= 1 + if self._entered_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None: """Set the sampling model on all MCP servers registered with the agent. diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index e1878e3f7..0770ab2e9 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -3,9 +3,10 @@ import base64 import functools from abc import ABC, abstractmethod +from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable @@ -60,12 +61,18 @@ class MCPServer(CallableToolset[Any], ABC): sampling_model: models.Model | None = None # } end of "abstract fields" - _running_count: int = 0 + _enter_lock: Lock = field(compare=False) + _running_count: int + _exit_stack: AsyncExitStack | None _client: ClientSession _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] - _exit_stack: AsyncExitStack + + def __post_init__(self): + self._enter_lock = Lock() + self._running_count = 0 + self._exit_stack = None @abstractmethod @asynccontextmanager @@ -86,7 +93,7 @@ def name(self) -> str: return repr(self) @property - def tool_name_conflict_hint(self) -> str: + def _tool_name_conflict_hint(self) -> str: return 'Consider setting `tool_prefix` to avoid name conflicts.' async def list_tools(self) -> list[mcp_types.Tool]: @@ -188,30 +195,35 @@ def _max_retries_for_tool(self, name: str) -> int: return self.max_retries async def __aenter__(self) -> Self: - if self._running_count == 0: - self._exit_stack = AsyncExitStack() - - self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams()) - client = ClientSession( - read_stream=self._read_stream, - write_stream=self._write_stream, - sampling_callback=self._sampling_callback if self.allow_sampling else None, - logging_callback=self.log_handler, - ) - self._client = await self._exit_stack.enter_async_context(client) + async with self._enter_lock: + if self._running_count == 0: + self._exit_stack = AsyncExitStack() + + self._read_stream, self._write_stream = await self._exit_stack.enter_async_context( + self.client_streams() + ) + client = ClientSession( + read_stream=self._read_stream, + write_stream=self._write_stream, + sampling_callback=self._sampling_callback if self.allow_sampling else None, + logging_callback=self.log_handler, + ) + self._client = await self._exit_stack.enter_async_context(client) - with anyio.fail_after(self.timeout): - await self._client.initialize() + with anyio.fail_after(self.timeout): + await self._client.initialize() - if log_level := self.log_level: - await self._client.set_logging_level(log_level) - self._running_count += 1 + if log_level := self.log_level: + await self._client.set_logging_level(log_level) + self._running_count += 1 return self async def __aexit__(self, *args: Any) -> bool | None: - self._running_count -= 1 - if self._running_count <= 0: - await self._exit_stack.aclose() + async with self._enter_lock: + self._running_count -= 1 + if self._running_count <= 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None @property def is_running(self) -> bool: diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index e79d932a8..299c03e98 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -364,6 +364,7 @@ class ToolDefinition: kind: ToolKind = field(default='function') """The kind of tool: + - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model - `'output'`: a tool that passes through an output value that ends the run - `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py index f94a57efa..8d5feaf24 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -27,7 +27,7 @@ def name(self) -> str: return self.__class__.__name__.replace('Toolset', ' toolset') @property - def tool_name_conflict_hint(self) -> str: + def _tool_name_conflict_hint(self) -> str: return 'Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.' async def __aenter__(self) -> Self: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/_callable.py b/pydantic_ai_slim/pydantic_ai/toolsets/_callable.py index 3b8562e82..e08ab751c 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/_callable.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/_callable.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import replace -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any from pydantic_core import SchemaValidator @@ -28,7 +28,7 @@ async def _call_tool(self, ctx: RunContext[AgentDepsT], name: str, tool_args: di async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any: ctx = replace(ctx, tool_name=call.tool_name, tool_call_id=call.tool_call_id) - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + pyd_allow_partial = 'trailing-strings' if allow_partial else 'off' validator = self._get_tool_args_validator(ctx, call.tool_name) if isinstance(call.args, str): args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index ce49aca1d..cc3385e6c 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Sequence from contextlib import AsyncExitStack -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Callable from typing_extensions import Self @@ -22,12 +22,16 @@ class CombinedToolset(AbstractToolset[AgentDepsT]): toolsets: list[AbstractToolset[AgentDepsT]] _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]] - _exit_stack: AsyncExitStack | None + + _enter_lock: asyncio.Lock = field(compare=False) _entered_count: int + _exit_stack: AsyncExitStack | None def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): - self._exit_stack = None + self._enter_lock = asyncio.Lock() self._entered_count = 0 + self._exit_stack = None + self.toolsets = list(toolsets) self._toolset_per_tool_name = {} @@ -36,28 +40,30 @@ def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]): try: existing_toolset = self._toolset_per_tool_name[name] raise UserError( - f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset._tool_name_conflict_hint}' ) except KeyError: pass self._toolset_per_tool_name[name] = toolset async def __aenter__(self) -> Self: - if self._entered_count == 0: - self._exit_stack = AsyncExitStack() - for toolset in self.toolsets: - await self._exit_stack.enter_async_context(toolset) - self._entered_count += 1 + async with self._enter_lock: + if self._entered_count == 0: + self._exit_stack = AsyncExitStack() + for toolset in self.toolsets: + await self._exit_stack.enter_async_context(toolset) + self._entered_count += 1 return self async def __aexit__(self, *args: Any) -> bool | None: - self._entered_count -= 1 - if self._entered_count <= 0 and self._exit_stack is not None: - await self._exit_stack.aclose() - self._exit_stack = None + async with self._enter_lock: + self._entered_count -= 1 + if self._entered_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]: - toolsets_for_run = await asyncio.gather(*[toolset.prepare_for_run(ctx) for toolset in self.toolsets]) + toolsets_for_run = await asyncio.gather(*(toolset.prepare_for_run(ctx) for toolset in self.toolsets)) combined_for_run = CombinedToolset(toolsets_for_run) return RunToolset(combined_for_run, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index ac039182e..21d5eab1a 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -11,7 +11,10 @@ class DeferredToolset(AbstractToolset[AgentDepsT]): - """A toolset that holds deferred tool.""" + """A toolset that holds deferred tools. + + See [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind] for more information about deferred tools. + """ _tool_defs: list[ToolDefinition] diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 3ae7ed3cc..7956f2d22 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -26,8 +26,8 @@ def name(self) -> str: return self.wrapped.name @property - def tool_name_conflict_hint(self) -> str: - return self.wrapped.tool_name_conflict_hint + def _tool_name_conflict_hint(self) -> str: + return self.wrapped._tool_name_conflict_hint async def __aenter__(self) -> Self: await self.wrapped.__aenter__() diff --git a/tests/test_agent.py b/tests/test_agent.py index 8e76fbd12..fd8f5e538 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3345,7 +3345,7 @@ def analyze_data() -> list[Any]: with pytest.raises( UserError, - match="analyze_data's return contains invalid nested ToolReturn objects. ToolReturn should be used directly.", + match="The return value of tool 'analyze_data' contains invalid nested `ToolReturn` objects. `ToolReturn` should be used directly.", ): agent.run_sync('Please analyze the data') @@ -3379,7 +3379,7 @@ def analyze_data() -> ToolReturn: with pytest.raises( UserError, - match="analyze_data's `return_value` contains invalid nested MultiModalContentTypes objects. Please use `content` instead.", + match="The `return_value` of tool 'analyze_data' contains invalid nested `MultiModalContentTypes` objects. Please use `content` instead.", ): agent.run_sync('Please analyze the data') diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index 9c2b0817a..ba2ec479c 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -472,7 +472,6 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef ) # Test prepare_for_run idempotency - # toolset.prepare_for_run(ctx1).prepare_for_run(ctx2) == toolset.prepare_for_run(ctx2) ctx1 = build_run_context(TestDeps(user_role='user', enable_advanced=True)) ctx2 = build_run_context(TestDeps(user_role='admin', enable_advanced=True)) toolset_once = await prepared_toolset.prepare_for_run(ctx2) @@ -484,7 +483,7 @@ async def test_context_manager(): try: from pydantic_ai.mcp import MCPServerStdio except ImportError: - return + pytest.skip('mcp is not installed') server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])