Skip to content

Commit 2348f45

Browse files
committed
Require DeferredToolCalls to be used with other output type
1 parent 0c96126 commit 2348f45

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ def build( # noqa: C901
190190

191191
outputs = [output for output in raw_outputs if output is not DeferredToolCalls]
192192
deferred_tool_calls = len(outputs) < len(raw_outputs)
193+
if len(outputs) == 0:
194+
if deferred_tool_calls:
195+
raise UserError('At least one output type must be provided other than DeferredToolCalls.')
196+
else:
197+
raise UserError('At least one output type must be provided.')
198+
193199
if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
194200
if len(outputs) > 1:
195201
raise UserError('NativeOutput cannot be mixed with other output types.')

tests/test_tools.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,3 +1211,32 @@ def test_deferred_tool():
12111211
},
12121212
)
12131213
)
1214+
1215+
1216+
def test_deferred_tool_with_output_type():
1217+
class MyModel(BaseModel):
1218+
foo: str
1219+
1220+
deferred_toolset = DeferredToolset(
1221+
[
1222+
ToolDefinition(
1223+
name='my_tool',
1224+
description='',
1225+
parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']},
1226+
),
1227+
]
1228+
)
1229+
agent = Agent(TestModel(call_tools=[]), output_type=[MyModel, DeferredToolCalls], toolsets=[deferred_toolset])
1230+
1231+
result = agent.run_sync('Hello')
1232+
assert result.output == snapshot(MyModel(foo='a'))
1233+
1234+
1235+
def test_output_type_deferred_tool_calls_by_itself():
1236+
with pytest.raises(UserError, match='At least one output type must be provided other than DeferredToolCalls.'):
1237+
Agent(TestModel(), output_type=DeferredToolCalls)
1238+
1239+
1240+
def test_output_type_empty():
1241+
with pytest.raises(UserError, match='At least one output type must be provided.'):
1242+
Agent(TestModel(), output_type=[])

0 commit comments

Comments
 (0)