@@ -155,7 +155,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
155
155
)
156
156
_function_toolset : FunctionToolset [AgentDepsT ] = dataclasses .field (repr = False )
157
157
_output_toolset : OutputToolset [AgentDepsT ] = dataclasses .field (repr = False )
158
+ _user_toolsets : Sequence [AbstractToolset [AgentDepsT ]] = dataclasses .field (repr = False )
159
+ _mcp_servers : Sequence [MCPServer ] = dataclasses .field (repr = False )
158
160
_toolset : AbstractToolset [AgentDepsT ] = dataclasses .field (repr = False )
161
+ _prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = dataclasses .field (repr = False )
159
162
_max_result_retries : int = dataclasses .field (repr = False )
160
163
_override_deps : _utils .Option [AgentDepsT ] = dataclasses .field (default = None , repr = False )
161
164
_override_model : _utils .Option [models .Model ] = dataclasses .field (default = None , repr = False )
@@ -179,7 +182,7 @@ def __init__(
179
182
tools : Sequence [Tool [AgentDepsT ] | ToolFuncEither [AgentDepsT , ...]] = (),
180
183
prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
181
184
mcp_servers : Sequence [MCPServer ] = (),
182
- toolsets : Sequence [AbstractToolset [AgentDepsT ]] = () ,
185
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
183
186
defer_model_check : bool = False ,
184
187
end_strategy : EndStrategy = 'early' ,
185
188
instrument : InstrumentationSettings | bool | None = None ,
@@ -210,7 +213,7 @@ def __init__(
210
213
tools : Sequence [Tool [AgentDepsT ] | ToolFuncEither [AgentDepsT , ...]] = (),
211
214
prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
212
215
mcp_servers : Sequence [MCPServer ] = (),
213
- toolsets : Sequence [AbstractToolset [AgentDepsT ]] = () ,
216
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
214
217
defer_model_check : bool = False ,
215
218
end_strategy : EndStrategy = 'early' ,
216
219
instrument : InstrumentationSettings | bool | None = None ,
@@ -238,7 +241,7 @@ def __init__(
238
241
mcp_servers : Sequence [
239
242
MCPServer
240
243
] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets
241
- toolsets : Sequence [AbstractToolset [AgentDepsT ]] = () ,
244
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
242
245
defer_model_check : bool = False ,
243
246
end_strategy : EndStrategy = 'early' ,
244
247
instrument : InstrumentationSettings | bool | None = None ,
@@ -361,19 +364,18 @@ def __init__(
361
364
self ._system_prompt_dynamic_functions = {}
362
365
363
366
self ._max_result_retries = output_retries if output_retries is not None else retries
367
+ self ._prepare_tools = prepare_tools
364
368
365
- self ._output_toolset = OutputToolset [AgentDepsT ](self ._output_schema , max_retries = self ._max_result_retries )
366
- self ._function_toolset = FunctionToolset [AgentDepsT ](tools , max_retries = retries )
369
+ self ._output_toolset = OutputToolset (self ._output_schema , max_retries = self ._max_result_retries )
370
+ self ._function_toolset = FunctionToolset (tools , max_retries = retries )
371
+ self ._user_toolsets = toolsets or ()
372
+ # TODO: Set max_retries on MCPServer
373
+ self ._mcp_servers = mcp_servers
367
374
368
375
# This will raise errors for any name conflicts
369
- # TODO: Also include toolsets (not mcp_serves as we won't have tool defs yet)
370
- CombinedToolset [AgentDepsT ]([self ._output_toolset , self ._function_toolset ])
371
-
372
- # TODO: Set max_retries on MCPServer
373
- toolset = CombinedToolset [AgentDepsT ]([self ._function_toolset , * toolsets , * mcp_servers ])
374
- if prepare_tools :
375
- toolset = PreparedToolset [AgentDepsT ](toolset , prepare_tools )
376
- self ._toolset = toolset
376
+ self ._toolset = CombinedToolset (
377
+ [self ._output_toolset , self ._function_toolset , * self ._user_toolsets , * self ._mcp_servers ]
378
+ )
377
379
378
380
self .history_processors = history_processors or []
379
381
@@ -395,6 +397,7 @@ async def run(
395
397
usage_limits : _usage .UsageLimits | None = None ,
396
398
usage : _usage .Usage | None = None ,
397
399
infer_name : bool = True ,
400
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
398
401
) -> AgentRunResult [OutputDataT ]: ...
399
402
400
403
@overload
@@ -410,6 +413,7 @@ async def run(
410
413
usage_limits : _usage .UsageLimits | None = None ,
411
414
usage : _usage .Usage | None = None ,
412
415
infer_name : bool = True ,
416
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
413
417
) -> AgentRunResult [RunOutputDataT ]: ...
414
418
415
419
@overload
@@ -426,6 +430,7 @@ async def run(
426
430
usage_limits : _usage .UsageLimits | None = None ,
427
431
usage : _usage .Usage | None = None ,
428
432
infer_name : bool = True ,
433
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
429
434
) -> AgentRunResult [RunOutputDataT ]: ...
430
435
431
436
async def run (
@@ -440,6 +445,7 @@ async def run(
440
445
usage_limits : _usage .UsageLimits | None = None ,
441
446
usage : _usage .Usage | None = None ,
442
447
infer_name : bool = True ,
448
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
443
449
** _deprecated_kwargs : Never ,
444
450
) -> AgentRunResult [Any ]:
445
451
"""Run the agent with a user prompt in async mode.
@@ -470,6 +476,7 @@ async def main():
470
476
usage_limits: Optional limits on model request count or token usage.
471
477
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
472
478
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
479
+ toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
473
480
474
481
Returns:
475
482
The result of the run.
@@ -494,6 +501,7 @@ async def main():
494
501
model_settings = model_settings ,
495
502
usage_limits = usage_limits ,
496
503
usage = usage ,
504
+ toolsets = toolsets ,
497
505
) as agent_run :
498
506
async for _ in agent_run :
499
507
pass
@@ -514,6 +522,7 @@ def iter(
514
522
usage_limits : _usage .UsageLimits | None = None ,
515
523
usage : _usage .Usage | None = None ,
516
524
infer_name : bool = True ,
525
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
517
526
** _deprecated_kwargs : Never ,
518
527
) -> AbstractAsyncContextManager [AgentRun [AgentDepsT , OutputDataT ]]: ...
519
528
@@ -530,6 +539,7 @@ def iter(
530
539
usage_limits : _usage .UsageLimits | None = None ,
531
540
usage : _usage .Usage | None = None ,
532
541
infer_name : bool = True ,
542
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
533
543
** _deprecated_kwargs : Never ,
534
544
) -> AbstractAsyncContextManager [AgentRun [AgentDepsT , RunOutputDataT ]]: ...
535
545
@@ -547,6 +557,7 @@ def iter(
547
557
usage_limits : _usage .UsageLimits | None = None ,
548
558
usage : _usage .Usage | None = None ,
549
559
infer_name : bool = True ,
560
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
550
561
) -> AbstractAsyncContextManager [AgentRun [AgentDepsT , Any ]]: ...
551
562
552
563
@asynccontextmanager
@@ -562,6 +573,7 @@ async def iter(
562
573
usage_limits : _usage .UsageLimits | None = None ,
563
574
usage : _usage .Usage | None = None ,
564
575
infer_name : bool = True ,
576
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
565
577
** _deprecated_kwargs : Never ,
566
578
) -> AsyncIterator [AgentRun [AgentDepsT , Any ]]:
567
579
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
@@ -636,6 +648,7 @@ async def main():
636
648
usage_limits: Optional limits on model request count or token usage.
637
649
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
638
650
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
651
+ toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
639
652
640
653
Returns:
641
654
The result of the run.
@@ -693,7 +706,11 @@ async def main():
693
706
run_step = state .run_step ,
694
707
)
695
708
696
- toolset = CombinedToolset ([output_toolset , self ._toolset ])
709
+ user_toolsets = self ._user_toolsets if toolsets is None else toolsets
710
+ toolset = CombinedToolset ([self ._function_toolset , * user_toolsets , * self ._mcp_servers ])
711
+ if self ._prepare_tools :
712
+ toolset = PreparedToolset (toolset , self ._prepare_tools )
713
+ toolset = CombinedToolset ([output_toolset , toolset ])
697
714
run_toolset = await toolset .prepare_for_run (run_context )
698
715
699
716
model_settings = merge_model_settings (self .model_settings , model_settings )
@@ -814,6 +831,7 @@ def run_sync(
814
831
usage_limits : _usage .UsageLimits | None = None ,
815
832
usage : _usage .Usage | None = None ,
816
833
infer_name : bool = True ,
834
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
817
835
) -> AgentRunResult [OutputDataT ]: ...
818
836
819
837
@overload
@@ -829,6 +847,7 @@ def run_sync(
829
847
usage_limits : _usage .UsageLimits | None = None ,
830
848
usage : _usage .Usage | None = None ,
831
849
infer_name : bool = True ,
850
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
832
851
) -> AgentRunResult [RunOutputDataT ]: ...
833
852
834
853
@overload
@@ -845,6 +864,7 @@ def run_sync(
845
864
usage_limits : _usage .UsageLimits | None = None ,
846
865
usage : _usage .Usage | None = None ,
847
866
infer_name : bool = True ,
867
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
848
868
) -> AgentRunResult [RunOutputDataT ]: ...
849
869
850
870
def run_sync (
@@ -859,6 +879,7 @@ def run_sync(
859
879
usage_limits : _usage .UsageLimits | None = None ,
860
880
usage : _usage .Usage | None = None ,
861
881
infer_name : bool = True ,
882
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
862
883
** _deprecated_kwargs : Never ,
863
884
) -> AgentRunResult [Any ]:
864
885
"""Synchronously run the agent with a user prompt.
@@ -888,6 +909,7 @@ def run_sync(
888
909
usage_limits: Optional limits on model request count or token usage.
889
910
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
890
911
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
912
+ toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
891
913
892
914
Returns:
893
915
The result of the run.
@@ -914,6 +936,7 @@ def run_sync(
914
936
usage_limits = usage_limits ,
915
937
usage = usage ,
916
938
infer_name = False ,
939
+ toolsets = toolsets ,
917
940
)
918
941
)
919
942
@@ -929,6 +952,7 @@ def run_stream(
929
952
usage_limits : _usage .UsageLimits | None = None ,
930
953
usage : _usage .Usage | None = None ,
931
954
infer_name : bool = True ,
955
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
932
956
) -> AbstractAsyncContextManager [result .StreamedRunResult [AgentDepsT , OutputDataT ]]: ...
933
957
934
958
@overload
@@ -944,6 +968,7 @@ def run_stream(
944
968
usage_limits : _usage .UsageLimits | None = None ,
945
969
usage : _usage .Usage | None = None ,
946
970
infer_name : bool = True ,
971
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
947
972
) -> AbstractAsyncContextManager [result .StreamedRunResult [AgentDepsT , RunOutputDataT ]]: ...
948
973
949
974
@overload
@@ -960,6 +985,7 @@ def run_stream(
960
985
usage_limits : _usage .UsageLimits | None = None ,
961
986
usage : _usage .Usage | None = None ,
962
987
infer_name : bool = True ,
988
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
963
989
) -> AbstractAsyncContextManager [result .StreamedRunResult [AgentDepsT , RunOutputDataT ]]: ...
964
990
965
991
@asynccontextmanager
@@ -975,6 +1001,7 @@ async def run_stream( # noqa C901
975
1001
usage_limits : _usage .UsageLimits | None = None ,
976
1002
usage : _usage .Usage | None = None ,
977
1003
infer_name : bool = True ,
1004
+ toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
978
1005
** _deprecated_kwargs : Never ,
979
1006
) -> AsyncIterator [result .StreamedRunResult [AgentDepsT , Any ]]:
980
1007
"""Run the agent with a user prompt in async mode, returning a streamed response.
@@ -1002,6 +1029,7 @@ async def main():
1002
1029
usage_limits: Optional limits on model request count or token usage.
1003
1030
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
1004
1031
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
1032
+ toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
1005
1033
1006
1034
Returns:
1007
1035
The result of the run.
@@ -1032,6 +1060,7 @@ async def main():
1032
1060
usage_limits = usage_limits ,
1033
1061
usage = usage ,
1034
1062
infer_name = False ,
1063
+ toolsets = toolsets ,
1035
1064
) as agent_run :
1036
1065
first_node = agent_run .next_node # start with the first node
1037
1066
assert isinstance (first_node , _agent_graph .UserPromptNode ) # the first node should be a user prompt node
0 commit comments