Skip to content

Commit 1cb7f32

Browse files
committed
Add LangChainToolset
1 parent 2e200ac commit 1cb7f32

File tree

4 files changed

+38
-25
lines changed

4 files changed

+38
-25
lines changed

pydantic_ai_slim/pydantic_ai/ext/langchain.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pydantic.json_schema import JsonSchemaValue
44

55
from pydantic_ai.tools import Tool
6+
from pydantic_ai.toolset import FunctionToolset
67

78

89
class LangChainTool(Protocol):
@@ -23,7 +24,7 @@ def description(self) -> str: ...
2324
def run(self, *args: Any, **kwargs: Any) -> str: ...
2425

2526

26-
__all__ = ('tool_from_langchain',)
27+
__all__ = ('tool_from_langchain', 'LangChainToolset')
2728

2829

2930
def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
@@ -59,3 +60,10 @@ def proxy(*args: Any, **kwargs: Any) -> str:
5960
description=function_description,
6061
json_schema=schema,
6162
)
63+
64+
65+
class LangChainToolset(FunctionToolset):
66+
"""A toolset that wraps LangChain tools."""
67+
68+
def __init__(self, tools: list[LangChainTool]):
69+
super().__init__([tool_from_langchain(tool) for tool in tools])

pydantic_ai_slim/pydantic_ai/toolset.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class AbstractToolset(ABC, Generic[AgentDepsT]):
4646

4747
@property
4848
def name(self) -> str:
49-
return self.__class__.__name__
49+
return self.__class__.__name__.replace('Toolset', ' toolset')
5050

5151
@property
5252
def name_conflict_hint(self) -> str:
@@ -110,10 +110,6 @@ class FunctionToolset(AbstractToolset[AgentDepsT]):
110110
max_retries: int = field(default=1)
111111
tools: dict[str, Tool[Any]] = field(default_factory=dict)
112112

113-
@property
114-
def name(self) -> str:
115-
return 'FunctionToolset'
116-
117113
def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1):
118114
self.max_retries = max_retries
119115
self.tools = {}

tests/ext/test_langchain.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydantic.json_schema import JsonSchemaValue
77

88
from pydantic_ai import Agent
9-
from pydantic_ai.ext.langchain import tool_from_langchain
9+
from pydantic_ai.ext.langchain import LangChainToolset, tool_from_langchain
1010

1111

1212
@dataclass
@@ -49,31 +49,40 @@ def get_input_jsonschema(self) -> JsonSchemaValue:
4949
}
5050

5151

52-
def test_langchain_tool_conversion():
53-
langchain_tool = SimulatedLangChainTool(
54-
name='file_search',
55-
description='Recursively search for files in a subdirectory that match the regex pattern',
56-
args={
57-
'dir_path': {
58-
'default': '.',
59-
'description': 'Subdirectory to search in.',
60-
'title': 'Dir Path',
61-
'type': 'string',
62-
},
63-
'pattern': {
64-
'description': 'Unix shell regex, where * matches everything.',
65-
'title': 'Pattern',
66-
'type': 'string',
67-
},
52+
langchain_tool = SimulatedLangChainTool(
53+
name='file_search',
54+
description='Recursively search for files in a subdirectory that match the regex pattern',
55+
args={
56+
'dir_path': {
57+
'default': '.',
58+
'description': 'Subdirectory to search in.',
59+
'title': 'Dir Path',
60+
'type': 'string',
6861
},
69-
)
62+
'pattern': {
63+
'description': 'Unix shell regex, where * matches everything.',
64+
'title': 'Pattern',
65+
'type': 'string',
66+
},
67+
},
68+
)
69+
70+
71+
def test_langchain_tool_conversion():
7072
pydantic_tool = tool_from_langchain(langchain_tool)
7173

7274
agent = Agent('test', tools=[pydantic_tool], retries=7)
7375
result = agent.run_sync('foobar')
7476
assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}")
7577

7678

79+
def test_langchain_toolset():
80+
toolset = LangChainToolset([langchain_tool])
81+
agent = Agent('test', toolsets=[toolset], retries=7)
82+
result = agent.run_sync('foobar')
83+
assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}")
84+
85+
7786
def test_langchain_tool_no_additional_properties():
7887
langchain_tool = SimulatedLangChainTool(
7988
name='file_search',

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def test_tool_return_conflict():
589589
# this raises an error
590590
with pytest.raises(
591591
UserError,
592-
match="FunctionToolset defines a tool whose name conflicts with existing tool from OutputToolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.",
592+
match="Function toolset defines a tool whose name conflicts with existing tool from OutputToolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.",
593593
):
594594
Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool'))
595595

0 commit comments

Comments
 (0)