diff --git a/CHANGELOG.md b/CHANGELOG.md
index 140b36bd2..cbd83dedd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@
 ### Fixed
 
 - Fixed a bug where `spacy` and `rapidfuzz` needed to be installed even if not using the relevant entity resolvers.
+- Fixed a bug where `VertexAILLM.(a)invoke_with_tools` called with multiple tools would raise an error.
 
 ### Changed
 
diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py
index b8b00da5b..ebe9fec22 100644
--- a/examples/customize/llms/vertexai_tool_calls.py
+++ b/examples/customize/llms/vertexai_tool_calls.py
@@ -4,6 +4,7 @@
 """
 
 import asyncio
+from typing import Optional
 
 from dotenv import load_dotenv
 from vertexai.generative_models import GenerationConfig
@@ -17,7 +18,7 @@
 
 
 # Create a custom Tool implementation for person info extraction
-parameters = ObjectParameter(
+person_tool_parameters = ObjectParameter(
     description="Parameters for extracting person information",
     properties={
         "name": StringParameter(description="The person's full name"),
@@ -29,7 +30,9 @@
 )
 
 
-def run_tool(name: str, age: int, occupation: str) -> str:
+def run_person_tool(
+    name: str, age: Optional[int] = None, occupation: Optional[str] = None
+) -> str:
     """A simple function that summarizes person information from input parameters."""
     return f"Found person {name} with age {age} and occupation {occupation}"
 
@@ -37,12 +40,40 @@ def run_tool(name: str, age: int, occupation: str) -> str:
 person_info_tool = Tool(
     name="extract_person_info",
     description="Extract information about a person from text",
-    parameters=parameters,
-    execute_func=run_tool,
+    parameters=person_tool_parameters,
+    execute_func=run_person_tool,
+)
+
+company_tool_parameters = ObjectParameter(
+    description="Parameters for extracting company information",
+    properties={
+        "name": StringParameter(description="The company's full name"),
+        "industry": StringParameter(description="The company's industry"),
+        "creation_year": IntegerParameter(description="The company's creation year"),
+    },
+    required_properties=["name"],
+    additional_properties=False,
+)
+
+
+def run_company_tool(
+    name: str, industry: Optional[str] = None, creation_year: Optional[int] = None
+) -> str:
+    """A simple function that summarizes company information from input parameters."""
+    return (
+        f"Found company {name} operating in industry {industry} since {creation_year}"
+    )
+
+
+company_info_tool = Tool(
+    name="extract_company_info",
+    description="Extract information about a company from text",
+    parameters=company_tool_parameters,
+    execute_func=run_company_tool,
 )
 
-# Create the tool instance
-TOOLS = [person_info_tool]
+# Create the list of tools
+TOOLS = [person_info_tool, company_info_tool]
 
 
 def process_tool_call(response: ToolCallResponse) -> str:
@@ -54,24 +85,34 @@ def process_tool_call(response: ToolCallResponse) -> str:
     print(f"\nTool called: {tool_call.name}")
     print(f"Arguments: {tool_call.arguments}")
     print(f"Additional content: {response.content or 'None'}")
-    return person_info_tool.execute(**tool_call.arguments)  # type: ignore[no-any-return]
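+    # Route the call to whichever tool the model selected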
model_name="gemini-1.5-flash-001", + model_name="gemini-2.0-flash-001", generation_config=generation_config, + # tool_config=ToolConfig( + # function_calling_config=ToolConfig.FunctionCallingConfig( + # mode=ToolConfig.FunctionCallingConfig.Mode.ANY, + # # allowed_function_names=["extract_person_info"], + # )) ) - # Example text containing information about a person - text = "Stella Hane is a 35-year-old software engineer who loves coding." + # Example text containing information about a company + text1 = "Neo4j is a software company created in 2007" print("\n=== Synchronous Tool Call ===") # Make a synchronous tool call sync_response = llm.invoke_with_tools( - input=f"Extract information about the person from this text: {text}", + input=f"Extract information about the person from this text: {text1}", tools=TOOLS, ) sync_result = process_tool_call(sync_response) @@ -79,7 +120,7 @@ async def main() -> None: print(sync_result) print("\n=== Asynchronous Tool Call ===") - # Make an asynchronous tool call with a different text + # Make an asynchronous tool call with a different text about a person text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning." async_response = await llm.ainvoke_with_tools( input=f"Extract information about the person from this text: {text2}", diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 100ff99ab..39d483915 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -40,6 +40,7 @@ Part, ResponseValidationError, Tool as VertexAITool, + ToolConfig, ) except ImportError: GenerativeModel = None @@ -137,20 +138,17 @@ def invoke( Returns: LLMResponse: The response from the LLM. """ - system_message = [system_instruction] if system_instruction is not None else [] - self.model = GenerativeModel( - model_name=self.model_name, - system_instruction=system_message, - **self.options, + model = self._get_model( + system_instruction=system_instruction, ) try: if isinstance(message_history, MessageHistory): message_history = message_history.messages - messages = self.get_messages(input, message_history) - response = self.model.generate_content(messages, **self.model_params) - return LLMResponse(content=response.text) + options = self._get_call_params(input, message_history, tools=None) + response = model.generate_content(**options) + return self._parse_content_response(response) except ResponseValidationError as e: - raise LLMGenerationError(e) + raise LLMGenerationError("Error calling VertexAILLM") from e async def ainvoke( self, @@ -172,31 +170,20 @@ async def ainvoke( try: if isinstance(message_history, MessageHistory): message_history = message_history.messages - system_message = ( - [system_instruction] if system_instruction is not None else [] - ) - self.model = GenerativeModel( - model_name=self.model_name, - system_instruction=system_message, - **self.options, + model = self._get_model( + system_instruction=system_instruction, ) - messages = self.get_messages(input, message_history) - response = await self.model.generate_content_async( - messages, **self.model_params - ) - return LLMResponse(content=response.text) + options = self._get_call_params(input, message_history, tools=None) + response = await model.generate_content_async(**options) + return self._parse_content_response(response) except ResponseValidationError as e: - raise LLMGenerationError(e) - - def _to_vertexai_tool(self, tool: Tool) -> VertexAITool: - return VertexAITool( - 
+        return [
+            VertexAITool(
+                function_declarations=[
+                    self._to_vertexai_function_declaration(tool) for tool in tools
+                ]
+            )
+        ]
 
     def _get_model(
         self,
         system_instruction: Optional[str] = None,
-        tools: Optional[Sequence[Tool]] = None,
     ) -> GenerativeModel:
         system_message = [system_instruction] if system_instruction is not None else []
-        vertex_ai_tools = self._get_llm_tools(tools)
         model = GenerativeModel(
             model_name=self.model_name,
             system_instruction=system_message,
-            tools=vertex_ai_tools,
-            **self.options,
         )
         return model
 
+    def _get_call_params(
+        self,
+        input: str,
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]],
+        tools: Optional[Sequence[Tool]],
+    ) -> dict[str, Any]:
+        options = dict(self.options)
+        if tools:
+            # we want a tool back, remove generation_config if defined
+            options.pop("generation_config", None)
+            options["tools"] = self._get_llm_tools(tools)
+            if "tool_config" not in options:
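+                # Mode.ANY constrains the model to always return a function
+                # call instead of a free-text answer; a tool_config already
+                # provided by the caller takes precedence over this default.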
+                options["tool_config"] = ToolConfig(
+                    function_calling_config=ToolConfig.FunctionCallingConfig(
+                        mode=ToolConfig.FunctionCallingConfig.Mode.ANY,
+                    )
+                )
+        else:
+            # no tools, remove tool_config if defined
+            options.pop("tool_config", None)
+
+        messages = self.get_messages(input, message_history)
+        options["contents"] = messages
+        return options
+
     async def _acall_llm(
         self,
         input: str,
@@ -228,9 +242,9 @@
         system_instruction: Optional[str] = None,
         tools: Optional[Sequence[Tool]] = None,
     ) -> GenerationResponse:
-        model = self._get_model(system_instruction=system_instruction, tools=tools)
-        messages = self.get_messages(input, message_history)
-        response = await model.generate_content_async(messages, **self.model_params)
+        model = self._get_model(system_instruction=system_instruction)
+        options = self._get_call_params(input, message_history, tools)
+        response = await model.generate_content_async(**options)
         return response
 
     def _call_llm(
@@ -240,9 +254,9 @@
         system_instruction: Optional[str] = None,
         tools: Optional[Sequence[Tool]] = None,
     ) -> GenerationResponse:
-        model = self._get_model(system_instruction=system_instruction, tools=tools)
-        messages = self.get_messages(input, message_history)
-        response = model.generate_content(messages, **self.model_params)
+        model = self._get_model(system_instruction=system_instruction)
+        options = self._get_call_params(input, message_history, tools)
+        response = model.generate_content(**options)
         return response
 
     def _to_tool_call(self, function_call: FunctionCall) -> ToolCall:
diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py
index b475efcc5..c937d2cb8 100644
--- a/tests/unit/llm/test_vertexai_llm.py
+++ b/tests/unit/llm/test_vertexai_llm.py
@@ -14,7 +14,6 @@
 from __future__ import annotations
 
 from typing import cast
-from unittest import mock
 from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
@@ -50,19 +49,20 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
     response = llm.invoke(input_text)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[]
+        model_name=model_name,
+        system_instruction=[],
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
-    content = last_call.args[0]
+    last_call = mock_model.generate_content.call_args_list[0]
+    content = last_call.kwargs["contents"]
     assert len(content) == 1
     assert content[0].role == "user"
     assert content[0].parts[0].text == input_text
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
 def test_vertexai_invoke_with_system_instruction(
+    mock_get_messages: MagicMock,
     GenerativeModelMock: MagicMock,
 ) -> None:
     system_instruction = "You are a helpful assistant."
@@ -72,16 +72,21 @@
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content.return_value = mock_response
+
+    mock_get_messages.return_value = [{"text": "some text"}]
+
     model_params = {"temperature": 0.5}
     llm = VertexAILLM(model_name, model_params)
     response = llm.invoke(input_text, system_instruction=system_instruction)
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name,
+        system_instruction=[system_instruction],
+    )
+    mock_model.generate_content.assert_called_once_with(
+        contents=[{"text": "some text"}]
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
@@ -110,12 +115,11 @@ def test_vertexai_invoke_with_message_history_and_system_instruction(
     )
     assert response.content == "Return text"
     GenerativeModelMock.assert_called_once_with(
-        model_name=model_name, system_instruction=[system_instruction]
+        model_name=model_name,
+        system_instruction=[system_instruction],
     )
-    user_message = mock.ANY
-    llm.model.generate_content.assert_called_once_with(user_message, **model_params)
-    last_call = llm.model.generate_content.call_args_list[0]
-    content = last_call.args[0]
+    last_call = mock_model.generate_content.call_args_list[0]
+    content = last_call.kwargs["contents"]
     assert len(content) == 3  # question + 2 messages in history
 
 
@@ -167,18 +171,22 @@ def test_vertexai_get_messages_validation_error(GenerativeModelMock: MagicMock)
 
 @pytest.mark.asyncio
 @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
-async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
+@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM.get_messages")
+async def test_vertexai_ainvoke_happy_path(
+    mock_get_messages: Mock, GenerativeModelMock: MagicMock
+) -> None:
     mock_response = AsyncMock()
     mock_response.text = "Return text"
     mock_model = GenerativeModelMock.return_value
     mock_model.generate_content_async = AsyncMock(return_value=mock_response)
+    mock_get_messages.return_value = [{"text": "Return text"}]
     model_params = {"temperature": 0.5}
     llm = VertexAILLM("gemini-1.5-flash-001", model_params)
     input_text = "may thy knife chip and shatter"
     response = await llm.ainvoke(input_text)
     assert response.content == "Return text"
-    llm.model.generate_content_async.assert_awaited_once_with(
-        [mock.ANY], **model_params
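+    # get_messages is patched, so its stub must reach the model as "contents"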
+    mock_model.generate_content_async.assert_awaited_once_with(
+        contents=[{"text": "Return text"}]
     )
 
 
@@ -235,13 +243,17 @@ def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None
     llm = VertexAILLM(model_name="gemini")
     tools = [test_tool]
 
-    res = llm._call_llm("my text", tools=tools)
-    assert isinstance(res, GenerationResponse)
+    with patch.object(llm, "_get_llm_tools", return_value=["my tools"]):
+        res = llm._call_llm("my text", tools=tools)
+        assert isinstance(res, GenerationResponse)
 
-    mock_model.assert_called_once_with(
-        system_instruction=None,
-        tools=tools,
-    )
+        mock_model.assert_called_once_with(
+            system_instruction=None,
+        )
+        calls = mock_generate_content.call_args_list
+        assert len(calls) == 1
+        assert calls[0][1]["tools"] == ["my tools"]
+        assert calls[0][1]["tool_config"] is not None
 
 
 @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response")
@@ -292,6 +304,5 @@ async def test_vertexai_acall_llm_with_tools(mock_model: Mock, test_tool: Tool)
     res = await llm._acall_llm("my text", tools=tools)
     mock_model.assert_called_once_with(
         system_instruction=None,
-        tools=tools,
     )
     assert isinstance(res, GenerationResponse)