Skip to content

Commit f7db040

Browse files
committed
Allow OutputSpec to be nested
1 parent ebf6f40 commit f7db040

File tree

3 files changed

+110
-80
lines changed

3 files changed

+110
-80
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 87 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from abc import ABC, abstractmethod
66
from collections.abc import Awaitable, Sequence
77
from dataclasses import dataclass, field, replace
8-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload
8+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Self, Union, cast, overload
99

1010
from pydantic import TypeAdapter, ValidationError
1111
from pydantic_core import SchemaValidator
@@ -26,6 +26,7 @@
2626
TextOutput,
2727
TextOutputFunc,
2828
ToolOutput,
29+
_OutputSpecItem, # type: ignore[reportPrivateUsage]
2930
)
3031
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
3132
from .toolsets import AbstractToolset
@@ -233,7 +234,7 @@ def build( # noqa: C901
233234
else:
234235
other_outputs.append(output)
235236

236-
toolset = cls._build_toolset(tool_outputs + other_outputs, name=name, description=description, strict=strict)
237+
toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict)
237238

238239
if len(text_outputs) > 0:
239240
if len(text_outputs) > 1:
@@ -268,73 +269,6 @@ def build( # noqa: C901
268269

269270
raise UserError('At least one output type must be provided.')
270271

271-
@staticmethod
272-
def _build_toolset(
273-
outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
274-
name: str | None = None,
275-
description: str | None = None,
276-
strict: bool | None = None,
277-
) -> OutputToolset[Any] | None:
278-
if len(outputs) == 0:
279-
return None
280-
281-
processors: dict[str, ObjectOutputProcessor[Any]] = {}
282-
tool_defs: list[ToolDefinition] = []
283-
284-
default_name = name or DEFAULT_OUTPUT_TOOL_NAME
285-
default_description = description
286-
default_strict = strict
287-
288-
multiple = len(outputs) > 1
289-
for output in outputs:
290-
name = None
291-
description = None
292-
strict = None
293-
if isinstance(output, ToolOutput):
294-
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
295-
name = output.name
296-
description = output.description
297-
strict = output.strict
298-
299-
output = output.output
300-
301-
if name is None:
302-
name = default_name
303-
if multiple:
304-
name += f'_{output.__name__}'
305-
306-
i = 1
307-
original_name = name
308-
while name in processors:
309-
i += 1
310-
name = f'{original_name}_{i}'
311-
312-
description = description or default_description
313-
if strict is None:
314-
strict = default_strict
315-
316-
processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
317-
object_def = processor.object_def
318-
319-
description = object_def.description
320-
if not description:
321-
description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
322-
if multiple:
323-
description = f'{object_def.name}: {description}'
324-
325-
tool_def = ToolDefinition(
326-
name=name,
327-
description=description,
328-
parameters_json_schema=object_def.json_schema,
329-
strict=object_def.strict,
330-
outer_typed_dict_key=processor.outer_typed_dict_key,
331-
kind='output',
332-
)
333-
processors[name] = processor
334-
tool_defs.append(tool_def)
335-
336-
return OutputToolset(processors=processors, tool_defs=tool_defs)
337-
338272
@staticmethod
339273
def _build_processor(
340274
outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
@@ -908,6 +842,74 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
908842
max_retries: int = field(default=1)
909843
output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list)
910844

845+
@classmethod
846+
def build(
847+
cls,
848+
outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
849+
name: str | None = None,
850+
description: str | None = None,
851+
strict: bool | None = None,
852+
) -> Self | None:
853+
if len(outputs) == 0:
854+
return None
855+
856+
processors: dict[str, ObjectOutputProcessor[Any]] = {}
857+
tool_defs: list[ToolDefinition] = []
858+
859+
default_name = name or DEFAULT_OUTPUT_TOOL_NAME
860+
default_description = description
861+
default_strict = strict
862+
863+
multiple = len(outputs) > 1
864+
for output in outputs:
865+
name = None
866+
description = None
867+
strict = None
868+
if isinstance(output, ToolOutput):
869+
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
870+
name = output.name
871+
description = output.description
872+
strict = output.strict
873+
874+
output = output.output
875+
876+
if name is None:
877+
name = default_name
878+
if multiple:
879+
name += f'_{output.__name__}'
880+
881+
i = 1
882+
original_name = name
883+
while name in processors:
884+
i += 1
885+
name = f'{original_name}_{i}'
886+
887+
description = description or default_description
888+
if strict is None:
889+
strict = default_strict
890+
891+
processor = ObjectOutputProcessor(output=output, description=description, strict=strict)
892+
object_def = processor.object_def
893+
894+
description = object_def.description
895+
if not description:
896+
description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
897+
if multiple:
898+
description = f'{object_def.name}: {description}'
899+
900+
tool_def = ToolDefinition(
901+
name=name,
902+
description=description,
903+
parameters_json_schema=object_def.json_schema,
904+
strict=object_def.strict,
905+
outer_typed_dict_key=processor.outer_typed_dict_key,
906+
kind='output',
907+
)
908+
processors[name] = processor
909+
tool_defs.append(tool_def)
910+
911+
return cls(processors=processors, tool_defs=tool_defs)
912+
911913
def __init__(
912914
self,
913915
tool_defs: list[ToolDefinition],
@@ -942,17 +944,29 @@ async def call_tool(
942944
return output
943945

944946

945-
def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]:
946-
outputs: Sequence[T]
947+
@overload
948+
def _flatten_output_spec(
949+
output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]],
950+
) -> Sequence[OutputTypeOrFunction[T]]: ...
951+
952+
953+
@overload
954+
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ...
955+
956+
957+
def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
958+
outputs: Sequence[OutputSpec[T]]
947959
if isinstance(output_spec, Sequence):
948960
outputs = output_spec
949961
else:
950962
outputs = (output_spec,)
951963

952-
outputs_flat: list[T] = []
964+
outputs_flat: list[_OutputSpecItem[T]] = []
953965
for output in outputs:
966+
if isinstance(output, Sequence):
967+
outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output)))
954968
if union_types := _utils.get_union_args(output):
955969
outputs_flat.extend(union_types)
956970
else:
957-
outputs_flat.append(output)
971+
outputs_flat.append(cast(_OutputSpecItem[T], output))
958972
return outputs_flat

pydantic_ai_slim/pydantic_ai/output.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,17 @@ def split_into_words(text: str) -> list[str]:
267267
"""The function that will be called to process the model's plain text output. The function must take a single string argument."""
268268

269269

270+
_OutputSpecItem = TypeAliasType(
271+
'_OutputSpecItem',
272+
Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]],
273+
type_params=(T_co,),
274+
)
275+
270276
OutputSpec = TypeAliasType(
271277
'OutputSpec',
272278
Union[
273-
OutputTypeOrFunction[T_co],
274-
ToolOutput[T_co],
275-
NativeOutput[T_co],
276-
PromptedOutput[T_co],
277-
TextOutput[T_co],
278-
Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]],
279+
_OutputSpecItem[T_co],
280+
Sequence['OutputSpec[T_co]'],
279281
],
280282
type_params=(T_co,),
281283
)

tests/typed_agent.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from pydantic_ai import Agent, ModelRetry, RunContext, Tool
1212
from pydantic_ai.agent import AgentRunResult
13-
from pydantic_ai.output import TextOutput, ToolOutput
13+
from pydantic_ai.output import DeferredToolCalls, TextOutput, ToolOutput
1414
from pydantic_ai.tools import ToolDefinition
1515

1616
# Define here so we can check `if MYPY` below. This will not be executed, MYPY will always set it to True
@@ -212,6 +212,14 @@ def my_method(self) -> bool:
212212
assert_type(
213213
complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]]
214214
)
215+
216+
complex_deferred_output_agent = Agent[
217+
None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls
218+
](output_type=[complex_output_agent.output_type, DeferredToolCalls])
219+
assert_type(
220+
complex_deferred_output_agent,
221+
Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls],
222+
)
215223
else:
216224
# pyright is able to correctly infer the type here
217225
async_int_function_agent = Agent(output_type=foobar_plain)
@@ -231,6 +239,12 @@ def my_method(self) -> bool:
231239
complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]]
232240
)
233241

242+
complex_deferred_output_agent = Agent(output_type=[complex_output_agent.output_type, DeferredToolCalls])
243+
assert_type(
244+
complex_deferred_output_agent,
245+
Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls],
246+
)
247+
234248

235249
Tool(foobar_ctx, takes_ctx=True)
236250
Tool(foobar_ctx)

0 commit comments

Comments
 (0)