diff --git a/litellm/__init__.py b/litellm/__init__.py index 922a0afcedef..35fa2629d1c1 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -813,7 +813,10 @@ def add_known_models(): from .cost_calculator import completion_cost from litellm.litellm_core_utils.litellm_logging import Logging, modify_integration from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider -from litellm.litellm_core_utils.core_helpers import remove_index_from_tool_calls +from litellm.litellm_core_utils.core_helpers import ( + remove_index_from_tool_calls, + remove_items_at_indices, +) from litellm.litellm_core_utils.token_counter import get_modified_max_tokens from .utils import ( client, diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 28a0097c30de..a57496f17f16 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,6 +1,6 @@ # What is this? ## Helper utilities -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union, Iterable import httpx @@ -67,7 +67,16 @@ def remove_index_from_tool_calls( ): # Type guard to ensure it's a dict tool_call.pop("index", None) - return + return + + +def remove_items_at_indices(items: Optional[List[Any]], indices: Iterable[int]) -> None: + """Remove items from a list in-place by index""" + if items is None: + return + for index in sorted(set(indices), reverse=True): + if 0 <= index < len(items): + items.pop(index) def add_missing_spend_metadata_to_litellm_metadata( diff --git a/litellm/vector_stores/vector_store_registry.py b/litellm/vector_stores/vector_store_registry.py index a045406d78a2..78dd8ffd5a4c 100644 --- a/litellm/vector_stores/vector_store_registry.py +++ b/litellm/vector_stores/vector_store_registry.py @@ -3,6 +3,8 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional +from litellm.litellm_core_utils.core_helpers import remove_items_at_indices + from litellm._logging import verbose_logger from litellm.types.vector_stores import ( LiteLLM_ManagedVectorStore, @@ -55,9 +57,26 @@ def pop_vector_store_ids_to_run( vector_store_ids = non_default_params.pop("vector_store_ids", None) or [] # 2. check if vector_store_ids is provided as a tool in the request - vector_store_ids = self._get_vector_store_ids_from_tool_calls( - tools=tools, vector_store_ids=vector_store_ids - ) + if tools: + tools_to_remove: List[int] = [] + for i, tool in enumerate(tools): + tool_vector_store_ids: List[str] = tool.get("vector_store_ids", []) + if len(tool_vector_store_ids) == 0: + continue + + vector_store_ids.extend(tool_vector_store_ids) + + # remove the tool if all vector_store_ids are recognised in the registry + recognised = all( + any(vs.get("vector_store_id") == vs_id for vs in self.vector_stores) + for vs_id in tool_vector_store_ids + ) + if recognised: + tools_to_remove.append(i) + + # remove recognised tools from the original list + remove_items_at_indices(tools, tools_to_remove) + return vector_store_ids def get_vector_store_to_run( diff --git a/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py b/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py index a8afe9ebbaa1..5706de85030b 100644 --- a/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py +++ b/tests/logging_callback_tests/test_bedrock_knowledgebase_hook.py @@ -161,9 +161,12 @@ async def test_openai_with_knowledge_base_mock_openai(setup_vector_store_registr # Verify the API was called mock_client.assert_called_once() request_body = mock_client.call_args.kwargs - + # Verify the request contains messages with knowledge base context assert "messages" in request_body + # The original tools field should be removed once the vector store ids + # have been processed by LiteLLM. + assert "tools" not in request_body messages = request_body["messages"] # We expect at least 2 messages: