diff --git a/CHANGELOG.md b/CHANGELOG.md index c5eb72654..49de506c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Support for Python 3.13 - Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default. - Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call). +- Added automatic rate limiting with retry logic and exponential backoff for all LLM providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely. ### Fixed diff --git a/docs/source/api.rst b/docs/source/api.rst index 55a5d1cc4..d8280b1cc 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -347,6 +347,28 @@ MistralAILLM :members: +Rate Limiting +============= + +RateLimitHandler +---------------- + +.. autoclass:: neo4j_graphrag.llm.rate_limit.RateLimitHandler + :members: + +RetryRateLimitHandler +--------------------- + +.. autoclass:: neo4j_graphrag.llm.rate_limit.RetryRateLimitHandler + :members: + +NoOpRateLimitHandler +-------------------- + +.. autoclass:: neo4j_graphrag.llm.rate_limit.NoOpRateLimitHandler + :members: + + PromptTemplate ============== @@ -473,6 +495,8 @@ Errors * :class:`neo4j_graphrag.exceptions.LLMGenerationError` + * :class:`neo4j_graphrag.exceptions.RateLimitError` + * :class:`neo4j_graphrag.exceptions.SchemaValidationError` * :class:`neo4j_graphrag.exceptions.PdfLoaderError` @@ -597,6 +621,13 @@ LLMGenerationError :show-inheritance: +RateLimitError +============== + +.. autoclass:: neo4j_graphrag.exceptions.RateLimitError + :show-inheritance: + + SchemaValidationError ===================== diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index 1ad76ef91..937a7099c 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -294,6 +294,91 @@ Here's an example using the Python Ollama client: See :ref:`llminterface`. +Rate Limit Handling +=================== + +All LLM implementations include automatic rate limiting that uses retry logic with exponential backoff by default. This feature helps handle API rate limits from LLM providers gracefully by automatically retrying failed requests with increasing wait times between attempts. + +Default Rate Limit Handler +-------------------------- + +Rate limiting is enabled by default for all LLM instances with the following configuration: + +- **Max attempts**: 3 +- **Min wait**: 1.0 seconds +- **Max wait**: 60.0 seconds +- **Multiplier**: 2.0 (exponential backoff) + +.. code:: python + + from neo4j_graphrag.llm import OpenAILLM + + # Rate limiting is automatically enabled + llm = OpenAILLM(model_name="gpt-4o") + + # The LLM will automatically retry on rate limit errors + response = llm.invoke("Hello, world!") + +.. note:: + + To change the default configuration of `RetryRateLimitHandler`: + + .. 
code:: python + + from neo4j_graphrag.llm import OpenAILLM + from neo4j_graphrag.llm.rate_limit import RetryRateLimitHandler + + # Customize rate limiting parameters + llm = OpenAILLM( + model_name="gpt-4o", + rate_limit_handler=RetryRateLimitHandler( + max_attempts=10, # Increase max retry attempts + min_wait=2.0, # Increase minimum wait time + max_wait=120.0, # Increase maximum wait time + multiplier=3.0 # More aggressive backoff + ) + ) + +Custom Rate Limiting +-------------------- + +You can customize the rate limiting behavior by creating your own rate limit handler: + +.. code:: python + + from neo4j_graphrag.llm import AnthropicLLM + from neo4j_graphrag.llm.rate_limit import RateLimitHandler + + class CustomRateLimitHandler(RateLimitHandler): + """Implement your custom rate limiting strategy.""" + # Implement required methods: handle_sync, handle_async + pass + + # Create custom rate limit handler and pass it to the LLM interface + custom_handler = CustomRateLimitHandler() + + llm = AnthropicLLM( + model_name="claude-3-sonnet-20240229", + rate_limit_handler=custom_handler, + ) + +Disabling Rate Limiting +----------------------- + +For high-throughput applications or when you handle rate limiting externally, you can disable it: + +.. code:: python + + from neo4j_graphrag.llm import CohereLLM, NoOpRateLimitHandler + + # Disable rate limiting completely + llm = CohereLLM( + model_name="command-r-plus", + rate_limit_handler=NoOpRateLimitHandler(), + ) + llm.invoke("Hello, world!") + + Configuring the Prompt ======================== diff --git a/poetry.lock b/poetry.lock index cbd375ac0..5cb194fcb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -4028,8 +4028,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4942,8 +4942,8 @@ grpcio = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ {version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""}, - {version = ">=1.26", markers = "python_version == \"3.12\""}, {version = ">=1.21,<2.1.0", markers = "python_version < \"3.10\""}, + {version = ">=1.26", markers = "python_version == \"3.12\""}, {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, ] portalocker = ">=2.7.0,<3.0.0" @@ -6281,7 +6281,7 @@ widechars = ["wcwidth"] name = "tenacity" version = "9.1.2" description = "Retry code until it succeeds" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"}, @@ -7370,4 +7370,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.14" -content-hash = "f53f3dfff909ce5fadc0f38896354f2952cc22098bd2dcd043a7de8e89026375" +content-hash = "83b68416feaf289d06e1af48ec8b7a3ac20ec0585be6d80f5bb0fb5b7deda025" diff --git a/pyproject.toml b/pyproject.toml index 320fc11e0..b44c2fa64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ scipy = [ { version = "^1.13.0", python = ">=3.9,<3.13" }, { version = "^1.15.0", python = ">=3.13,<3.14" } ] +tenacity = "^9.1.2" [tool.poetry.group.dev.dependencies] urllib3 = "<2" diff --git a/src/neo4j_graphrag/exceptions.py b/src/neo4j_graphrag/exceptions.py index 681b20eec..9faffff99 100644 --- a/src/neo4j_graphrag/exceptions.py +++ b/src/neo4j_graphrag/exceptions.py @@ -138,3 +138,7 @@ class InvalidHybridSearchRankerError(Neo4jGraphRagError): class SearchQueryParseError(Neo4jGraphRagError): """Exception raised when there is a query parse error in the text search string.""" + + +class RateLimitError(LLMGenerationError): + """Exception raised when API rate limit is exceeded.""" diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index a9ece5ccb..3c4f65d9a 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -18,6 +18,13 @@ from .mistralai_llm import MistralAILLM from .ollama_llm import OllamaLLM from .openai_llm import AzureOpenAILLM, OpenAILLM +from .rate_limit import ( + RateLimitHandler, + NoOpRateLimitHandler, + RetryRateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from .types import LLMResponse from .vertexai_llm import VertexAILLM @@ -31,4 +38,10 @@ "VertexAILLM", "AzureOpenAILLM", "MistralAILLM", + # Rate limiting components + "RateLimitHandler", + "NoOpRateLimitHandler", + "RetryRateLimitHandler", + "rate_limit_handler", + "async_rate_limit_handler", ] diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 881156e3f..6bafef85b 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -19,6 +19,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + 
rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -62,6 +67,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): try: @@ -71,7 +77,7 @@ def __init__( """Could not import Anthropic Python client. Please install it with `pip install "neo4j-graphrag[anthropic]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.anthropic = anthropic self.client = anthropic.Anthropic(**kwargs) self.async_client = anthropic.AsyncAnthropic(**kwargs) @@ -93,6 +99,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return messages # type: ignore + @rate_limit_handler def invoke( self, input: str, @@ -129,6 +136,7 @@ def invoke( except self.anthropic.APIError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 87d281794..cca710bc9 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -21,9 +21,14 @@ from neo4j_graphrag.types import LLMMessage from .types import LLMResponse, ToolCallResponse +from .rate_limit import ( + DEFAULT_RATE_LIMIT_HANDLER, +) from neo4j_graphrag.tool import Tool +from .rate_limit import RateLimitHandler + class LLMInterface(ABC): """Interface for large language models. @@ -31,6 +36,7 @@ class LLMInterface(ABC): Args: model_name (str): The name of the language model. model_params (Optional[dict]): Additional parameters passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. """ @@ -38,11 +44,17 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): self.model_name = model_name self.model_params = model_params or {} + if rate_limit_handler is not None: + self._rate_limit_handler = rate_limit_handler + else: + self._rate_limit_handler = DEFAULT_RATE_LIMIT_HANDLER + @abstractmethod def invoke( self, diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index ecddd53ea..7c3905500 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -20,6 +20,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -60,6 +65,7 @@ def __init__( self, model_name: str = "", model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ) -> None: try: @@ -69,7 +75,7 @@ def __init__( """Could not import cohere python client. 
Please install it with `pip install "neo4j-graphrag[cohere]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.cohere = cohere self.cohere_api_error = cohere.core.api_error.ApiError @@ -96,6 +102,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return messages # type: ignore + @rate_limit_handler def invoke( self, input: str, @@ -127,6 +134,7 @@ def invoke( content=res.message.content[0].text if res.message.content else "", ) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index 9e44287bc..ae2a6312f 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -21,6 +21,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -44,6 +49,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): """ @@ -52,6 +58,7 @@ def __init__( model_name (str): model_params (str): Parameters like temperature and such that will be passed to the chat completions endpoint + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. kwargs: All other parameters will be passed to the Mistral client. """ @@ -60,7 +67,7 @@ def __init__( """Could not import Mistral Python client. Please install it with `pip install "neo4j-graphrag[mistralai]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) api_key = kwargs.pop("api_key", None) if api_key is None: api_key = os.getenv("MISTRAL_API_KEY", "") @@ -86,6 +93,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return cast(list[Messages], messages) + @rate_limit_handler def invoke( self, input: str, @@ -124,6 +132,7 @@ def invoke( except SDKError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 5abb13d8f..6c4728888 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -23,6 +23,7 @@ from neo4j_graphrag.types import LLMMessage from .base import LLMInterface +from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler from .types import ( BaseMessage, LLMResponse, @@ -40,6 +41,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): try: @@ -49,7 +51,7 @@ def __init__( "Could not import ollama Python client. " "Please install it with `pip install ollama`." 
) - super().__init__(model_name, model_params, **kwargs) + super().__init__(model_name, model_params, rate_limit_handler) self.ollama = ollama self.client = ollama.Client( **kwargs, @@ -78,6 +80,7 @@ def get_messages( messages.append(UserMessage(content=input).model_dump()) return messages # type: ignore + @rate_limit_handler def invoke( self, input: str, @@ -108,6 +111,7 @@ def invoke( except self.ollama.ResponseError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 1e0228e45..ed8af1958 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -39,6 +39,7 @@ from ..exceptions import LLMGenerationError from .base import LLMInterface +from .rate_limit import RateLimitHandler, rate_limit_handler, async_rate_limit_handler from .types import ( BaseMessage, LLMResponse, @@ -63,6 +64,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, ): """ Base class for OpenAI LLM. @@ -72,6 +74,7 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. """ try: import openai @@ -81,7 +84,7 @@ def __init__( Please install it with `pip install "neo4j-graphrag[openai]"`.""" ) self.openai = openai - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) def get_messages( self, @@ -124,6 +127,7 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: except AttributeError: raise LLMGenerationError(f"Tool {tool} is not a valid Tool object") + @rate_limit_handler def invoke( self, input: str, @@ -158,6 +162,7 @@ def invoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) + @rate_limit_handler def invoke_with_tools( self, input: str, @@ -232,6 +237,7 @@ def invoke_with_tools( except self.openai.OpenAIError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke( self, input: str, @@ -266,6 +272,7 @@ async def ainvoke( except self.openai.OpenAIError as e: raise LLMGenerationError(e) + @async_rate_limit_handler async def ainvoke_with_tools( self, input: str, @@ -347,6 +354,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): """OpenAI LLM @@ -356,9 +364,10 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.client = self.openai.OpenAI(**kwargs) self.async_client = self.openai.AsyncOpenAI(**kwargs) @@ -369,6 +378,7 @@ def __init__( model_name: str, model_params: Optional[dict[str, Any]] = None, system_instruction: Optional[str] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): """Azure OpenAI LLM. 
Use this class when using an OpenAI model @@ -377,8 +387,9 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + rate_limit_handler (Optional[RateLimitHandler]): Handler for rate limiting. Defaults to retry with exponential backoff. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.client = self.openai.AzureOpenAI(**kwargs) self.async_client = self.openai.AsyncAzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/llm/rate_limit.py b/src/neo4j_graphrag/llm/rate_limit.py new file mode 100644 index 000000000..af6dacfb9 --- /dev/null +++ b/src/neo4j_graphrag/llm/rate_limit.py @@ -0,0 +1,270 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import functools +import logging +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, TypeVar + +from neo4j_graphrag.exceptions import RateLimitError + +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + wait_random_exponential, + retry_if_exception_type, + before_sleep_log, +) + + +logger = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) +AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]]) + + +class RateLimitHandler(ABC): + """Abstract base class for rate limit handling strategies.""" + + @abstractmethod + def handle_sync(self, func: F) -> F: + """Apply rate limit handling to a synchronous function. + + Args: + func: The function to wrap with rate limit handling. + + Returns: + The wrapped function. + """ + pass + + @abstractmethod + def handle_async(self, func: AF) -> AF: + """Apply rate limit handling to an asynchronous function. + + Args: + func: The async function to wrap with rate limit handling. + + Returns: + The wrapped async function. + """ + pass + + +class NoOpRateLimitHandler(RateLimitHandler): + """A no-op rate limit handler that does not apply any rate limiting.""" + + def handle_sync(self, func: F) -> F: + """Return the function unchanged.""" + return func + + def handle_async(self, func: AF) -> AF: + """Return the async function unchanged.""" + return func + + +class RetryRateLimitHandler(RateLimitHandler): + """Rate limit handler using exponential backoff retry strategy. + + This handler uses tenacity for retry logic with exponential backoff. + + Args: + max_attempts: Maximum number of retry attempts. Defaults to 3. + min_wait: Minimum wait time between retries in seconds. Defaults to 1. + max_wait: Maximum wait time between retries in seconds. Defaults to 60. + multiplier: Exponential backoff multiplier. Defaults to 2. + jitter: Whether to add random jitter to retry delays to prevent thundering herd. Defaults to True. 
+ """ + + def __init__( + self, + max_attempts: int = 3, + min_wait: float = 1.0, + max_wait: float = 60.0, + multiplier: float = 2.0, + jitter: bool = True, + ): + self.max_attempts = max_attempts + self.min_wait = min_wait + self.max_wait = max_wait + self.multiplier = multiplier + self.jitter = jitter + + def _get_wait_strategy(self) -> Any: + """Get the appropriate wait strategy based on jitter setting. + + Returns: + The configured wait strategy for tenacity retry. + """ + if self.jitter: + # Use built-in random exponential backoff with jitter + return wait_random_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ) + else: + # Use standard exponential backoff without jitter + return wait_exponential( + multiplier=self.multiplier, + min=self.min_wait, + max=self.max_wait, + ) + + def handle_sync(self, func: F) -> F: + """Apply retry logic to a synchronous function.""" + decorator = retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=self._get_wait_strategy(), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + return decorator(func) + + def handle_async(self, func: AF) -> AF: + """Apply retry logic to an asynchronous function.""" + decorator = retry( + retry=retry_if_exception_type(RateLimitError), + stop=stop_after_attempt(self.max_attempts), + wait=self._get_wait_strategy(), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + return decorator(func) + + +def is_rate_limit_error(exception: Exception) -> bool: + """Check if an exception is a rate limit error from any LLM provider. + + Args: + exception: The exception to check. + + Returns: + True if the exception indicates a rate limit error, False otherwise. + """ + # Already converted to RateLimitError + if isinstance(exception, RateLimitError): + return True + + error_type = type(exception).__name__.lower() + exception_str = str(exception).lower() + + # OpenAI - specific error type + if error_type == "ratelimiterror": + return True + + # Check for HTTP 429 status code (various providers) + if hasattr(exception, "status_code") and getattr(exception, "status_code") == 429: + return True + + if hasattr(exception, "response"): + response = getattr(exception, "response") + if hasattr(response, "status_code") and response.status_code == 429: + return True + + # Provider-specific error types with message checks + rate_limit_error_types = { + "apierror": "too many requests", # Anthropic, Cohere + "sdkerror": "too many requests", # MistralAI + "responseerror": "too many requests", # Ollama + "responsevalidationerror": "resource exhausted", # VertexAI (special case) + } + + if error_type in rate_limit_error_types: + required_message = rate_limit_error_types[error_type] + return required_message in exception_str + + return False + + +def convert_to_rate_limit_error(exception: Exception) -> RateLimitError: + """Convert a provider-specific rate limit exception to RateLimitError. + + Args: + exception: The original exception from the LLM provider. + + Returns: + A RateLimitError with the original exception message. + """ + return RateLimitError(f"Rate limit exceeded: {exception}") + + +def rate_limit_handler(func: F) -> F: + """Decorator to apply rate limit handling to synchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler. + + Args: + func: The function to wrap with rate limit handling. + + Returns: + The wrapped function. 
+ """ + + @functools.wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) + + def inner_func() -> Any: + try: + return func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return active_handler.handle_sync(inner_func)() + + return wrapper # type: ignore + + +def async_rate_limit_handler(func: AF) -> AF: + """Decorator to apply rate limit handling to asynchronous methods. + + This decorator works with instance methods and uses the instance's rate limit handler. + + Args: + func: The async function to wrap with rate limit handling. + + Returns: + The wrapped async function. + """ + + @functools.wraps(func) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + # Use instance handler or default + active_handler = getattr( + self, "_rate_limit_handler", DEFAULT_RATE_LIMIT_HANDLER + ) + + async def inner_func() -> Any: + try: + return await func(self, *args, **kwargs) + except Exception as e: + if is_rate_limit_error(e): + raise convert_to_rate_limit_error(e) + raise + + return await active_handler.handle_async(inner_func)() + + return wrapper # type: ignore + + +# Default rate limit handler instance +DEFAULT_RATE_LIMIT_HANDLER = RetryRateLimitHandler() diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 39d483915..5b772c35b 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -19,6 +19,11 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + rate_limit_handler, + async_rate_limit_handler, +) from neo4j_graphrag.llm.types import ( BaseMessage, LLMResponse, @@ -78,6 +83,7 @@ def __init__( model_name: str = "gemini-1.5-flash-001", model_params: Optional[dict[str, Any]] = None, system_instruction: Optional[str] = None, + rate_limit_handler: Optional[RateLimitHandler] = None, **kwargs: Any, ): if GenerativeModel is None or ResponseValidationError is None: @@ -85,7 +91,7 @@ def __init__( """Could not import Vertex AI Python client. Please install it with `pip install "neo4j-graphrag[google]"`.""" ) - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limit_handler) self.model_name = model_name self.system_instruction = system_instruction self.options = kwargs @@ -121,6 +127,7 @@ def get_messages( messages.append(Content(role="user", parts=[Part.from_text(input)])) return messages + @rate_limit_handler def invoke( self, input: str, @@ -150,6 +157,7 @@ def invoke( except ResponseValidationError as e: raise LLMGenerationError("Error calling VertexAILLM") from e + @async_rate_limit_handler async def ainvoke( self, input: str, diff --git a/tests/unit/llm/test_rate_limit.py b/tests/unit/llm/test_rate_limit.py new file mode 100644 index 000000000..f1f4b133b --- /dev/null +++ b/tests/unit/llm/test_rate_limit.py @@ -0,0 +1,175 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Awaitable + +import pytest +from unittest.mock import Mock +from tenacity import RetryError + +from neo4j_graphrag.llm.rate_limit import ( + RateLimitHandler, + NoOpRateLimitHandler, + DEFAULT_RATE_LIMIT_HANDLER, +) +from neo4j_graphrag.exceptions import RateLimitError + + +def test_default_handler_retries_sync() -> None: + call_count = 0 + + def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise RateLimitError("Rate limit exceeded") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_sync(mock_func) + + with pytest.raises(RetryError): + wrapped_func() + + assert call_count == 3 + + +@pytest.mark.asyncio +async def test_default_handler_retries_async() -> None: + call_count = 0 + + async def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise RateLimitError("Rate limit exceeded") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_async(mock_func) + + with pytest.raises(RetryError): + await wrapped_func() + + assert call_count == 3 + + +def test_other_errors_pass_through_sync() -> None: + call_count = 0 + + def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("Some other error") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_sync(mock_func) + + with pytest.raises(ValueError): + wrapped_func() + + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_other_errors_pass_through_async() -> None: + call_count = 0 + + async def mock_func() -> None: + nonlocal call_count + call_count += 1 + raise ValueError("Some other error") + + wrapped_func = DEFAULT_RATE_LIMIT_HANDLER.handle_async(mock_func) + + with pytest.raises(ValueError): + await wrapped_func() + + assert call_count == 1 + + +def test_noop_handler_sync() -> None: + def mock_func() -> str: + return "test result" + + handler = NoOpRateLimitHandler() + wrapped_func = handler.handle_sync(mock_func) + + assert wrapped_func() == "test result" + assert wrapped_func is mock_func + + +@pytest.mark.asyncio +async def test_noop_handler_async() -> None: + async def mock_func() -> str: + return "async test result" + + handler = NoOpRateLimitHandler() + wrapped_func = handler.handle_async(mock_func) + + assert await wrapped_func() == "async test result" + assert wrapped_func is mock_func + + +def test_custom_handler_sync_retry_override() -> None: + call_count = 0 + + def mock_func() -> str: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RateLimitError("Rate limit exceeded") + return "success after custom retry" + + # Custom handler with single retry + def custom_handle_sync(func: Callable[[], Any]) -> Callable[[], Any]: + def wrapper() -> Any: + try: + return func() + except RateLimitError: + return func() # Retry once + + return wrapper + + handler = Mock(spec=RateLimitHandler) + handler.handle_sync = custom_handle_sync + + result = handler.handle_sync(mock_func)() + assert result == "success after custom retry" + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_custom_handler_async_retry_override() -> None: + call_count = 0 + + async def mock_func() -> str: + nonlocal call_count 
+ call_count += 1 + if call_count == 1: + raise RateLimitError("Rate limit exceeded") + return "success after custom retry" + + # Custom handler with single retry + def custom_handle_async( + func: Callable[[], Awaitable[Any]], + ) -> Callable[[], Awaitable[Any]]: + async def wrapper() -> Any: + try: + return await func() + except RateLimitError: + return await func() # Retry once + + return wrapper + + handler = Mock(spec=RateLimitHandler) + handler.handle_async = custom_handle_async + + result = await handler.handle_async(mock_func)() + assert result == "success after custom retry" + assert call_count == 2
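
---

Reviewer note: the "Custom Rate Limiting" section of the user guide added in this diff leaves the custom handler as a stub (`pass`). For reference, below is a minimal sketch of what a concrete handler could look like, using only the `RateLimitHandler` ABC (`handle_sync`/`handle_async`) and `RateLimitError` introduced in this PR. The `FixedDelayRateLimitHandler` name and the fixed-delay retry policy are illustrative assumptions, not part of the library.

```python
import asyncio
import functools
import time
from typing import Any, Awaitable, Callable, TypeVar

from neo4j_graphrag.exceptions import RateLimitError
from neo4j_graphrag.llm.rate_limit import RateLimitHandler

F = TypeVar("F", bound=Callable[..., Any])
AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]])


class FixedDelayRateLimitHandler(RateLimitHandler):
    """Retry rate-limited calls a fixed number of times with a constant delay."""

    def __init__(self, max_attempts: int = 3, delay: float = 5.0) -> None:
        self.max_attempts = max_attempts
        self.delay = delay

    def handle_sync(self, func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            for attempt in range(1, self.max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except RateLimitError:
                    # Re-raise once the retry budget is exhausted
                    if attempt == self.max_attempts:
                        raise
                    time.sleep(self.delay)

        return wrapper  # type: ignore[return-value]

    def handle_async(self, func: AF) -> AF:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            for attempt in range(1, self.max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except RateLimitError:
                    if attempt == self.max_attempts:
                        raise
                    await asyncio.sleep(self.delay)

        return wrapper  # type: ignore[return-value]
```

Such a handler would be passed through the new `rate_limit_handler` constructor argument, e.g. `OpenAILLM(model_name="gpt-4o", rate_limit_handler=FixedDelayRateLimitHandler())`, exactly as the `RetryRateLimitHandler` and `NoOpRateLimitHandler` examples in the user guide do.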