Feature/vertexai tool invocation #328


Merged
merged 15 commits on Apr 28, 2025
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -4,7 +4,7 @@

### Added

-- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
+- Added tool calling functionality to the LLM base class with OpenAI and VertexAI implementations, enabling structured parameter extraction and function calling.
- Added support for multi-vector collection in Qdrant driver.
- Added a `Pipeline.stream` method to stream pipeline progress.
- Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
@@ -13,7 +13,7 @@
### Changed

- Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging.
- Switched from pygraphviz to neo4j-viz
  - Now renders an interactive graph in HTML instead of PNG
  - Removed `get_pygraphviz_graph` method

1 change: 1 addition & 0 deletions examples/README.md
@@ -79,6 +79,7 @@ are listed in [the last section of this file](#customize).
- [System Instruction](./customize/llms/llm_with_system_instructions.py)

- [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)
- [Tool Calling with VertexAI](./customize/llms/vertexai_tool_calls.py)


### Prompts
95 changes: 95 additions & 0 deletions examples/customize/llms/vertexai_tool_calls.py
@@ -0,0 +1,95 @@
"""
Example showing how to use VertexAI tool calls with parameter extraction.
Both synchronous and asynchronous examples are provided.
"""

import asyncio

from dotenv import load_dotenv
from vertexai.generative_models import GenerationConfig

from neo4j_graphrag.llm import VertexAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter

# Load environment variables from .env file
load_dotenv()


# Create a custom Tool implementation for person info extraction
parameters = ObjectParameter(
    description="Parameters for extracting person information",
    properties={
        "name": StringParameter(description="The person's full name"),
        "age": IntegerParameter(description="The person's age"),
        "occupation": StringParameter(description="The person's occupation"),
    },
    required_properties=["name"],
    additional_properties=False,
)


def run_tool(name: str, age: int, occupation: str) -> str:
    """A simple function that summarizes person information from input parameters."""
    return f"Found person {name} with age {age} and occupation {occupation}"


person_info_tool = Tool(
    name="extract_person_info",
    description="Extract information about a person from text",
    parameters=parameters,
    execute_func=run_tool,
)

# Collect the tools that will be passed to the LLM
TOOLS = [person_info_tool]


def process_tool_call(response: ToolCallResponse) -> str:
    """Process the tool call response and return the extracted parameters."""
    if not response.tool_calls:
        raise ValueError("No tool calls found in response")

    tool_call = response.tool_calls[0]
    print(f"\nTool called: {tool_call.name}")
    print(f"Arguments: {tool_call.arguments}")
    print(f"Additional content: {response.content or 'None'}")
    return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]


async def main() -> None:
    # Initialize the VertexAI LLM
    generation_config = GenerationConfig(temperature=0.0)
    llm = VertexAILLM(
        model_name="gemini-1.5-flash-001",
        generation_config=generation_config,
    )

    # Example text containing information about a person
    text = "Stella Hane is a 35-year-old software engineer who loves coding."

    print("\n=== Synchronous Tool Call ===")
    # Make a synchronous tool call
    sync_response = llm.invoke_with_tools(
        input=f"Extract information about the person from this text: {text}",
        tools=TOOLS,
    )
    sync_result = process_tool_call(sync_response)
    print("\n=== Synchronous Tool Call Result ===")
    print(sync_result)

    print("\n=== Asynchronous Tool Call ===")
    # Make an asynchronous tool call with a different text
    text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
    async_response = await llm.ainvoke_with_tools(
        input=f"Extract information about the person from this text: {text2}",
        tools=TOOLS,
    )
    async_result = process_tool_call(async_response)
    print("\n=== Asynchronous Tool Call Result ===")
    print(async_result)


if __name__ == "__main__":
    # Run the async main function
    asyncio.run(main())
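
For reference, a successful run of this example prints output along these lines (illustrative only — the exact arguments depend on what the model extracts; the values below simply follow from the script's print statements):

=== Synchronous Tool Call ===

Tool called: extract_person_info
Arguments: {'name': 'Stella Hane', 'age': 35, 'occupation': 'software engineer'}
Additional content: None

=== Synchronous Tool Call Result ===
Found person Stella Hane with age 35 and occupation software engineer
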
120 changes: 118 additions & 2 deletions src/neo4j_graphrag/llm/vertexai_llm.py
@@ -13,22 +13,33 @@
# limitations under the License.
from __future__ import annotations

-from typing import Any, List, Optional, Union, cast
+from typing import Any, List, Optional, Union, cast, Sequence

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList
+from neo4j_graphrag.llm.types import (
+    BaseMessage,
+    LLMResponse,
+    MessageList,
+    ToolCall,
+    ToolCallResponse,
+)
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.tool import Tool
from neo4j_graphrag.types import LLMMessage

try:
    from vertexai.generative_models import (
        Content,
        FunctionCall,
        FunctionDeclaration,
        GenerationResponse,
        GenerativeModel,
        Part,
        ResponseValidationError,
        Tool as VertexAITool,
    )
except ImportError:
    GenerativeModel = None

@@ -176,3 +187,108 @@ async def ainvoke(
            return LLMResponse(content=response.text)
        except ResponseValidationError as e:
            raise LLMGenerationError(e)

    def _to_vertexai_tool(self, tool: Tool) -> VertexAITool:
        return VertexAITool(
            function_declarations=[
                FunctionDeclaration(
                    name=tool.get_name(),
                    description=tool.get_description(),
                    parameters=tool.get_parameters(exclude=["additional_properties"]),
                )
            ]
        )

    def _get_llm_tools(
        self, tools: Optional[Sequence[Tool]]
    ) -> Optional[list[VertexAITool]]:
        if not tools:
            return None
        return [self._to_vertexai_tool(tool) for tool in tools]

    def _get_model(
        self,
        system_instruction: Optional[str] = None,
        tools: Optional[Sequence[Tool]] = None,
    ) -> GenerativeModel:
        system_message = [system_instruction] if system_instruction is not None else []
        vertex_ai_tools = self._get_llm_tools(tools)
        model = GenerativeModel(
            model_name=self.model_name,
            system_instruction=system_message,
            tools=vertex_ai_tools,
            **self.options,
        )
        return model

    async def _acall_llm(
        self,
        input: str,
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
        tools: Optional[Sequence[Tool]] = None,
    ) -> GenerationResponse:
        model = self._get_model(system_instruction=system_instruction, tools=tools)
        messages = self.get_messages(input, message_history)
        response = await model.generate_content_async(messages, **self.model_params)
        return response

    def _call_llm(
        self,
        input: str,
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
        tools: Optional[Sequence[Tool]] = None,
    ) -> GenerationResponse:
        model = self._get_model(system_instruction=system_instruction, tools=tools)
        messages = self.get_messages(input, message_history)
        response = model.generate_content(messages, **self.model_params)
        return response

    def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:
        return ToolCall(
            name=function_call.name,
            arguments=function_call.args,
        )

    def _parse_tool_response(self, response: GenerationResponse) -> ToolCallResponse:
        function_calls = response.candidates[0].function_calls
        return ToolCallResponse(
            tool_calls=[self._to_tool_call(f) for f in function_calls],
            content=None,
        )

    def _parse_content_response(self, response: GenerationResponse) -> LLMResponse:
        return LLMResponse(
            content=response.text,
        )

    async def ainvoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        response = await self._acall_llm(
            input,
            message_history=message_history,
            system_instruction=system_instruction,
            tools=tools,
        )
        return self._parse_tool_response(response)

    def invoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        response = self._call_llm(
            input,
            message_history=message_history,
            system_instruction=system_instruction,
            tools=tools,
        )
        return self._parse_tool_response(response)
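
Together these helpers translate the library's provider-agnostic `Tool` into a VertexAI `FunctionDeclaration`, attach it to the `GenerativeModel`, and map the returned `FunctionCall`s back into a `ToolCallResponse`. A minimal sketch of a caller (assumes Vertex AI credentials are configured; the weather tool and prompt are invented for illustration):

from neo4j_graphrag.llm import VertexAILLM
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter

# A hypothetical single-parameter tool; the model fills in `city` from the prompt.
weather_tool = Tool(
    name="get_weather",
    description="Look up the current weather for a city",
    parameters=ObjectParameter(
        description="Weather lookup parameters",
        properties={"city": StringParameter(description="City name")},
        required_properties=["city"],
        additional_properties=False,
    ),
    execute_func=lambda city: f"Sunny in {city}",
)

llm = VertexAILLM(model_name="gemini-1.5-flash-001")
response = llm.invoke_with_tools(
    input="What is the weather like in Malmö today?",
    tools=[weather_tool],
)
# Each ToolCall carries the function name and the model-extracted arguments.
for call in response.tool_calls:
    print(call.name, call.arguments)
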
18 changes: 10 additions & 8 deletions src/neo4j_graphrag/tool.py
@@ -169,18 +169,21 @@ def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]:
        values["properties"] = new_props
        return values

-    def model_dump_tool(self) -> Dict[str, Any]:
+    def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
+        exclude = exclude or []
        properties_dict: Dict[str, Any] = {}
        for name, param in self.properties.items():
+            if name in exclude:
+                continue
            properties_dict[name] = param.model_dump_tool()

        result = super().model_dump_tool()
        result["properties"] = properties_dict

-        if self.required_properties:
+        if self.required_properties and "required" not in exclude:
            result["required"] = self.required_properties

-        if not self.additional_properties:
+        if not self.additional_properties and "additional_properties" not in exclude:
            result["additionalProperties"] = False

        return result
@@ -242,22 +245,21 @@ def get_description(self) -> str:
"""
return self._description

def get_parameters(self) -> Dict[str, Any]:
def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
"""Get the parameters the tool accepts in a dictionary format suitable for LLM providers.

Returns:
Dict[str, Any]: Dictionary containing parameter schema information.
"""
return self._parameters.model_dump_tool()
return self._parameters.model_dump_tool(exclude)

def execute(self, query: str, **kwargs: Any) -> Any:
def execute(self, **kwargs: Any) -> Any:
"""Execute the tool with the given query and additional parameters.

Args:
query (str): The query or input for the tool to process.
**kwargs (Any): Additional parameters for the tool.

Returns:
Any: The result of the tool execution.
"""
return self._execute_func(query, **kwargs)
return self._execute_func(**kwargs)
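
The new `exclude` hook is what lets the VertexAI path above call `get_parameters(exclude=["additional_properties"])` — presumably because Vertex AI's function schema does not accept that key. A rough sketch of the effect (dictionary contents are indicative; the exact base keys come from the parameter classes):

from neo4j_graphrag.tool import ObjectParameter, StringParameter

params = ObjectParameter(
    description="Test parameters",
    properties={"param1": StringParameter(description="Test parameter")},
    required_properties=["param1"],
    additional_properties=False,
)

full = params.model_dump_tool()
# `full` includes 'required': ['param1'] and 'additionalProperties': False.

trimmed = params.model_dump_tool(exclude=["additional_properties"])
# 'additionalProperties' is now omitted; 'required' is kept because
# only the names listed in `exclude` are dropped.
assert "additionalProperties" not in trimmed
assert trimmed["required"] == ["param1"]
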
27 changes: 27 additions & 0 deletions tests/unit/llm/conftest.py
@@ -0,0 +1,27 @@
import pytest

from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter


class TestTool(Tool):
    """Test tool for unit tests."""

    def __init__(self, name: str = "test_tool", description: str = "A test tool"):
        parameters = ObjectParameter(
            description="Test parameters",
            properties={"param1": StringParameter(description="Test parameter")},
            required_properties=["param1"],
            additional_properties=False,
        )

        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
            execute_func=lambda **kwargs: kwargs,
        )


@pytest.fixture
def test_tool() -> Tool:
    return TestTool()
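
A test consuming this fixture might look like the following sketch (hypothetical — the PR's actual unit tests are not shown in this excerpt):

from neo4j_graphrag.tool import Tool

def test_get_parameters_respects_exclude(test_tool: Tool) -> None:
    params = test_tool.get_parameters(exclude=["additional_properties"])
    assert "additionalProperties" not in params
    assert params["required"] == ["param1"]
    # The excluded key is still present in the unfiltered dump.
    assert test_tool.get_parameters()["additionalProperties"] is False
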