Skip to content

Commit 017bcff

Browse files
committed
Implement tool calling for VertexAILLM
1 parent ce50e5f commit 017bcff

File tree

2 files changed

+120
-4
lines changed

2 files changed

+120
-4
lines changed

src/neo4j_graphrag/llm/vertexai_llm.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,33 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
from typing import Any, List, Optional, Union, cast
16+
from typing import Any, List, Optional, Union, cast, Sequence
1717

1818
from pydantic import ValidationError
1919

2020
from neo4j_graphrag.exceptions import LLMGenerationError
2121
from neo4j_graphrag.llm.base import LLMInterface
22-
from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList
22+
from neo4j_graphrag.llm.types import (
23+
BaseMessage,
24+
LLMResponse,
25+
MessageList,
26+
ToolCall,
27+
ToolCallResponse,
28+
)
2329
from neo4j_graphrag.message_history import MessageHistory
30+
from neo4j_graphrag.tool import Tool
2431
from neo4j_graphrag.types import LLMMessage
2532

2633
try:
2734
from vertexai.generative_models import (
2835
Content,
36+
FunctionCall,
37+
FunctionDeclaration,
38+
GenerationResponse,
2939
GenerativeModel,
3040
Part,
3141
ResponseValidationError,
42+
Tool as VertexAITool,
3243
)
3344
except ImportError:
3445
GenerativeModel = None
@@ -176,3 +187,108 @@ async def ainvoke(
176187
return LLMResponse(content=response.text)
177188
except ResponseValidationError as e:
178189
raise LLMGenerationError(e)
190+
191+
def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
192+
return VertexAITool(
193+
function_declarations=[
194+
FunctionDeclaration(
195+
name=tool.get_name(),
196+
description=tool.get_description(),
197+
parameters=tool.get_parameters(),
198+
)
199+
]
200+
)
201+
202+
def get_tools(
203+
self, tools: Optional[Sequence[Tool]]
204+
) -> Optional[list[VertexAITool]]:
205+
if not tools:
206+
return None
207+
return [self._to_vertexai_tool(tool) for tool in tools]
208+
209+
def _get_model(
210+
self,
211+
system_instruction: Optional[str] = None,
212+
tools: Optional[Sequence[Tool]] = None,
213+
) -> GenerativeModel:
214+
system_message = [system_instruction] if system_instruction is not None else []
215+
vertex_ai_tools = self.get_tools(tools)
216+
model = GenerativeModel(
217+
model_name=self.model_name,
218+
system_instruction=system_message,
219+
tools=vertex_ai_tools,
220+
**self.options,
221+
)
222+
return model
223+
224+
async def _acall_llm(
225+
self,
226+
input: str,
227+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
228+
system_instruction: Optional[str] = None,
229+
tools: Optional[Sequence[Tool]] = None,
230+
) -> GenerationResponse:
231+
model = self._get_model(system_instruction, tools)
232+
messages = self.get_messages(input, message_history)
233+
response = await model.generate_content_async(messages, **self.model_params)
234+
return response
235+
236+
def _call_llm(
237+
self,
238+
input: str,
239+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
240+
system_instruction: Optional[str] = None,
241+
tools: Optional[Sequence[Tool]] = None,
242+
) -> GenerationResponse:
243+
model = self._get_model(system_instruction, tools)
244+
messages = self.get_messages(input, message_history)
245+
response = model.generate_content(messages, **self.model_params)
246+
return response
247+
248+
def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:
249+
return ToolCall(
250+
name=function_call.name,
251+
arguments=function_call.args,
252+
)
253+
254+
def _parse_tool_response(self, response) -> ToolCallResponse:
255+
function_calls = response.candidates[0].function_calls
256+
return ToolCallResponse(
257+
tool_calls=[self._to_tool_call(f) for f in function_calls],
258+
content=None,
259+
)
260+
261+
def _parse_content_response(self, response) -> LLMResponse:
262+
return LLMResponse(
263+
content=response.text,
264+
)
265+
266+
async def ainvoke_with_tools(
267+
self,
268+
input: str,
269+
tools: Sequence[Tool],
270+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
271+
system_instruction: Optional[str] = None,
272+
) -> ToolCallResponse:
273+
response = await self._acall_llm(
274+
input,
275+
message_history=message_history,
276+
system_instruction=system_instruction,
277+
tools=tools,
278+
)
279+
return self._parse_tool_response(response)
280+
281+
def invoke_with_tools(
282+
self,
283+
input: str,
284+
tools: Sequence[Tool],
285+
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
286+
system_instruction: Optional[str] = None,
287+
) -> ToolCallResponse:
288+
response = self._call_llm(
289+
input,
290+
message_history=message_history,
291+
system_instruction=system_instruction,
292+
tools=tools,
293+
)
294+
return self._parse_tool_response(response)

src/neo4j_graphrag/tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def model_dump_tool(self) -> Dict[str, Any]:
180180
if self.required_properties:
181181
result["required"] = self.required_properties
182182

183-
if not self.additional_properties:
184-
result["additionalProperties"] = False
183+
# if not self.additional_properties:
184+
# result["additionalProperties"] = False
185185

186186
return result
187187

0 commit comments

Comments
 (0)