|
5 | 5 | from abc import ABC, abstractmethod
|
6 | 6 | from collections.abc import Awaitable, Sequence
|
7 | 7 | from dataclasses import dataclass, field, replace
|
8 |
| -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload |
| 8 | +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Self, Union, cast, overload |
9 | 9 |
|
10 | 10 | from pydantic import TypeAdapter, ValidationError
|
11 | 11 | from pydantic_core import SchemaValidator
|
|
26 | 26 | TextOutput,
|
27 | 27 | TextOutputFunc,
|
28 | 28 | ToolOutput,
|
| 29 | + _OutputSpecItem, # type: ignore[reportPrivateUsage] |
29 | 30 | )
|
30 | 31 | from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
|
31 | 32 | from .toolsets import AbstractToolset
|
@@ -233,7 +234,7 @@ def build( # noqa: C901
|
233 | 234 | else:
|
234 | 235 | other_outputs.append(output)
|
235 | 236 |
|
236 |
| - toolset = cls._build_toolset(tool_outputs + other_outputs, name=name, description=description, strict=strict) |
| 237 | + toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict) |
237 | 238 |
|
238 | 239 | if len(text_outputs) > 0:
|
239 | 240 | if len(text_outputs) > 1:
|
@@ -268,73 +269,6 @@ def build( # noqa: C901
|
268 | 269 |
|
269 | 270 | raise UserError('At least one output type must be provided.')
|
270 | 271 |
|
271 |
| - @staticmethod |
272 |
| - def _build_toolset( |
273 |
| - outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], |
274 |
| - name: str | None = None, |
275 |
| - description: str | None = None, |
276 |
| - strict: bool | None = None, |
277 |
| - ) -> OutputToolset[Any] | None: |
278 |
| - if len(outputs) == 0: |
279 |
| - return None |
280 |
| - |
281 |
| - processors: dict[str, ObjectOutputProcessor[Any]] = {} |
282 |
| - tool_defs: list[ToolDefinition] = [] |
283 |
| - |
284 |
| - default_name = name or DEFAULT_OUTPUT_TOOL_NAME |
285 |
| - default_description = description |
286 |
| - default_strict = strict |
287 |
| - |
288 |
| - multiple = len(outputs) > 1 |
289 |
| - for output in outputs: |
290 |
| - name = None |
291 |
| - description = None |
292 |
| - strict = None |
293 |
| - if isinstance(output, ToolOutput): |
294 |
| - # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads |
295 |
| - name = output.name |
296 |
| - description = output.description |
297 |
| - strict = output.strict |
298 |
| - |
299 |
| - output = output.output |
300 |
| - |
301 |
| - if name is None: |
302 |
| - name = default_name |
303 |
| - if multiple: |
304 |
| - name += f'_{output.__name__}' |
305 |
| - |
306 |
| - i = 1 |
307 |
| - original_name = name |
308 |
| - while name in processors: |
309 |
| - i += 1 |
310 |
| - name = f'{original_name}_{i}' |
311 |
| - |
312 |
| - description = description or default_description |
313 |
| - if strict is None: |
314 |
| - strict = default_strict |
315 |
| - |
316 |
| - processor = ObjectOutputProcessor(output=output, description=description, strict=strict) |
317 |
| - object_def = processor.object_def |
318 |
| - |
319 |
| - description = object_def.description |
320 |
| - if not description: |
321 |
| - description = DEFAULT_OUTPUT_TOOL_DESCRIPTION |
322 |
| - if multiple: |
323 |
| - description = f'{object_def.name}: {description}' |
324 |
| - |
325 |
| - tool_def = ToolDefinition( |
326 |
| - name=name, |
327 |
| - description=description, |
328 |
| - parameters_json_schema=object_def.json_schema, |
329 |
| - strict=object_def.strict, |
330 |
| - outer_typed_dict_key=processor.outer_typed_dict_key, |
331 |
| - kind='output', |
332 |
| - ) |
333 |
| - processors[name] = processor |
334 |
| - tool_defs.append(tool_def) |
335 |
| - |
336 |
| - return OutputToolset(processors=processors, tool_defs=tool_defs) |
337 |
| - |
338 | 272 | @staticmethod
|
339 | 273 | def _build_processor(
|
340 | 274 | outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
|
@@ -908,6 +842,74 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
|
908 | 842 | max_retries: int = field(default=1)
|
909 | 843 | output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list)
|
910 | 844 |
|
| 845 | + @classmethod |
| 846 | + def build( |
| 847 | + cls, |
| 848 | + outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], |
| 849 | + name: str | None = None, |
| 850 | + description: str | None = None, |
| 851 | + strict: bool | None = None, |
| 852 | + ) -> Self | None: |
| 853 | + if len(outputs) == 0: |
| 854 | + return None |
| 855 | + |
| 856 | + processors: dict[str, ObjectOutputProcessor[Any]] = {} |
| 857 | + tool_defs: list[ToolDefinition] = [] |
| 858 | + |
| 859 | + default_name = name or DEFAULT_OUTPUT_TOOL_NAME |
| 860 | + default_description = description |
| 861 | + default_strict = strict |
| 862 | + |
| 863 | + multiple = len(outputs) > 1 |
| 864 | + for output in outputs: |
| 865 | + name = None |
| 866 | + description = None |
| 867 | + strict = None |
| 868 | + if isinstance(output, ToolOutput): |
| 869 | + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads |
| 870 | + name = output.name |
| 871 | + description = output.description |
| 872 | + strict = output.strict |
| 873 | + |
| 874 | + output = output.output |
| 875 | + |
| 876 | + if name is None: |
| 877 | + name = default_name |
| 878 | + if multiple: |
| 879 | + name += f'_{output.__name__}' |
| 880 | + |
| 881 | + i = 1 |
| 882 | + original_name = name |
| 883 | + while name in processors: |
| 884 | + i += 1 |
| 885 | + name = f'{original_name}_{i}' |
| 886 | + |
| 887 | + description = description or default_description |
| 888 | + if strict is None: |
| 889 | + strict = default_strict |
| 890 | + |
| 891 | + processor = ObjectOutputProcessor(output=output, description=description, strict=strict) |
| 892 | + object_def = processor.object_def |
| 893 | + |
| 894 | + description = object_def.description |
| 895 | + if not description: |
| 896 | + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION |
| 897 | + if multiple: |
| 898 | + description = f'{object_def.name}: {description}' |
| 899 | + |
| 900 | + tool_def = ToolDefinition( |
| 901 | + name=name, |
| 902 | + description=description, |
| 903 | + parameters_json_schema=object_def.json_schema, |
| 904 | + strict=object_def.strict, |
| 905 | + outer_typed_dict_key=processor.outer_typed_dict_key, |
| 906 | + kind='output', |
| 907 | + ) |
| 908 | + processors[name] = processor |
| 909 | + tool_defs.append(tool_def) |
| 910 | + |
| 911 | + return cls(processors=processors, tool_defs=tool_defs) |
| 912 | + |
911 | 913 | def __init__(
|
912 | 914 | self,
|
913 | 915 | tool_defs: list[ToolDefinition],
|
@@ -942,17 +944,29 @@ async def call_tool(
|
942 | 944 | return output
|
943 | 945 |
|
944 | 946 |
|
945 |
| -def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: |
946 |
| - outputs: Sequence[T] |
| 947 | +@overload |
| 948 | +def _flatten_output_spec( |
| 949 | + output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]], |
| 950 | +) -> Sequence[OutputTypeOrFunction[T]]: ... |
| 951 | + |
| 952 | + |
| 953 | +@overload |
| 954 | +def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ... |
| 955 | + |
| 956 | + |
| 957 | +def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: |
| 958 | + outputs: Sequence[OutputSpec[T]] |
947 | 959 | if isinstance(output_spec, Sequence):
|
948 | 960 | outputs = output_spec
|
949 | 961 | else:
|
950 | 962 | outputs = (output_spec,)
|
951 | 963 |
|
952 |
| - outputs_flat: list[T] = [] |
| 964 | + outputs_flat: list[_OutputSpecItem[T]] = [] |
953 | 965 | for output in outputs:
|
| 966 | + if isinstance(output, Sequence): |
| 967 | + outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output))) |
954 | 968 | if union_types := _utils.get_union_args(output):
|
955 | 969 | outputs_flat.extend(union_types)
|
956 | 970 | else:
|
957 |
| - outputs_flat.append(output) |
| 971 | + outputs_flat.append(cast(_OutputSpecItem[T], output)) |
958 | 972 | return outputs_flat
|
0 commit comments