diff --git a/examples/customize/build_graph/components/writers/neo4j_writer.py b/examples/customize/build_graph/components/writers/neo4j_writer.py index 60dbcea97..fb6b40046 100644 --- a/examples/customize/build_graph/components/writers/neo4j_writer.py +++ b/examples/customize/build_graph/components/writers/neo4j_writer.py @@ -1,19 +1,35 @@ +import asyncio + import neo4j from neo4j_graphrag.experimental.components.kg_writer import ( KGWriterModel, Neo4jWriter, ) -from neo4j_graphrag.experimental.components.types import Neo4jGraph +from neo4j_graphrag.experimental.components.types import Neo4jGraph, Neo4jNode -async def main(driver: neo4j.Driver, graph: Neo4jGraph) -> KGWriterModel: +async def run_writer(driver: neo4j.Driver, graph: Neo4jGraph) -> KGWriterModel: writer = Neo4jWriter( driver, # optionally, configure the neo4j database # neo4j_database="neo4j", - # you can tune batch_size to - # improve speed + # you can tune batch_size to improve speed # batch_size=1000, ) result = await writer.run(graph=graph) return result + + +async def main(): + graph = Neo4jGraph( + nodes=[Neo4jNode(id="1", label="Label", properties={"name": "test"})] + ) + with neo4j.GraphDatabase.driver( + "bolt://localhost:7687", + auth=("neo4j", "password"), + ) as driver: + await run_writer(driver=driver, graph=graph) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/customize/llms/api_feedback_rate_limiting.py b/examples/customize/llms/api_feedback_rate_limiting.py new file mode 100644 index 000000000..0519ecba6 --- /dev/null +++ b/examples/customize/llms/api_feedback_rate_limiting.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/examples/customize/llms/precise_token_counting.py b/examples/customize/llms/precise_token_counting.py new file mode 100644 index 000000000..9cb1befc9 --- /dev/null +++ b/examples/customize/llms/precise_token_counting.py @@ -0,0 +1,316 @@ +""" +Precise Token Counting with OpenAI + +This example demonstrates how the OpenAI LLM implementation uses tiktoken +to count exact input tokens for precise rate limiting, rather than relying +on rough estimates. + +This makes rate limiting much more accurate and efficient. +""" + +import asyncio +import os +from typing import Dict, Any + +from neo4j_graphrag.llm import ( + OpenAILLM, + TokenBucketRateLimiter, + CompositeRateLimiter, + SlotBucketRateLimiter, + APIFeedbackRateLimiter, + RetryConfig, +) +from neo4j_graphrag.tool import Tool + + +def demonstrate_precise_token_counting(): + """Show how OpenAI uses precise token counting for rate limiting.""" + + print("=== Precise Token Counting with OpenAI ===") + + # Create a token-based rate limiter + token_limiter = TokenBucketRateLimiter( + tokens_per_second=500.0, # 30k tokens per minute + max_tokens=2000, # Burst capacity + ) + + # Create OpenAI LLM with token-based rate limiting + llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=token_limiter, + ) + + print(f"Created OpenAI LLM with precise token counting") + print(f"Tiktoken available: {hasattr(llm, '_tokenizer') and llm._tokenizer is not None}") + + # Test with different input sizes + test_inputs = [ + "Hello world!", # Short input + "Explain quantum computing in simple terms with examples and applications.", # Medium input + "Write a detailed explanation of machine learning, including supervised learning, unsupervised learning, reinforcement learning, neural networks, deep learning, and provide examples of each with real-world applications. Include the mathematical foundations and key algorithms." 
* 2, # Long input + ] + + for i, input_text in enumerate(test_inputs, 1): + print(f"\\nTest {i}: Input length = {len(input_text)} characters") + + # Show precise token count if available + if hasattr(llm, '_estimate_input_tokens'): + estimated_tokens = llm._estimate_input_tokens(input_text) + print(f"Precise token estimate: {estimated_tokens} tokens") + + # Check rate limiter status before request + status_before = token_limiter.get_status() + print(f"Tokens available before: {status_before.get('available_tokens', 'N/A')}") + + try: + response = llm.invoke(input_text) + print(f"Response: {response.content[:100]}...") + + # Check status after request + status_after = token_limiter.get_status() + print(f"Tokens available after: {status_after.get('available_tokens', 'N/A')}") + + except Exception as e: + print(f"Error: {e}") + + +def compare_estimation_vs_actual(): + """Compare estimated tokens vs actual tokens used.""" + + print("\\n=== Estimation vs Actual Token Usage ===") + + # Create rate limiter that tracks both estimates and actual usage + token_limiter = TokenBucketRateLimiter( + tokens_per_second=1000.0, + max_tokens=5000, + ) + + llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=token_limiter, + ) + + test_cases = [ + ("Simple question", "What is 2+2?"), + ("Complex question", "Explain the differences between machine learning and artificial intelligence, including their applications and limitations."), + ("With context", "Based on the previous discussion about AI, what are the ethical implications?"), + ] + + for name, input_text in test_cases: + print(f"\\n{name}:") + + # Get precise estimate + if hasattr(llm, '_estimate_input_tokens'): + estimated = llm._estimate_input_tokens(input_text) + print(f" Estimated input tokens: {estimated}") + + try: + response = llm.invoke(input_text) + print(f" Response: {response.content[:80]}...") + + # Note: Actual token usage would be updated in the rate limiter + # from the API response if using a TokenTrackingRateLimiter + + except Exception as e: + print(f" Error: {e}") + + +def demonstrate_tool_calling_token_counting(): + """Show token counting with tool calling (includes tool definitions).""" + + print("\\n=== Token Counting with Tool Calling ===") + + # Create a simple tool for demonstration + class WeatherTool(Tool): + def get_name(self) -> str: + return "get_weather" + + def get_description(self) -> str: + return "Get current weather information for a location" + + def get_parameters(self) -> Dict[str, Any]: + return { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit" + } + }, + "required": ["location"] + } + + def invoke(self, **kwargs) -> str: + return f"Weather in {kwargs.get('location', 'unknown')}: 22°C, sunny" + + # Create rate limiter + composite_limiter = CompositeRateLimiter([ + SlotBucketRateLimiter(slots_per_second=2.0, max_slots=5), # Request limiting + TokenBucketRateLimiter(tokens_per_second=800.0, max_tokens=3000), # Token limiting + ]) + + llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=composite_limiter, + ) + + tools = [WeatherTool()] + + print("Testing token counting with tool definitions...") + + # Estimate tokens including tool definitions + if hasattr(llm, '_estimate_input_tokens'): + input_text = "What's the weather like in San Francisco?" 
+ + # Tokens without tools + tokens_without_tools = llm._estimate_input_tokens(input_text) + print(f"Tokens without tools: {tokens_without_tools}") + + # Tokens with tools (includes tool definitions) + tokens_with_tools = llm._estimate_input_tokens(input_text, tools=tools) + print(f"Tokens with tools: {tokens_with_tools}") + + tool_overhead = tokens_with_tools - tokens_without_tools + print(f"Tool definition overhead: {tool_overhead} tokens") + + try: + response = llm.invoke_with_tools( + "What's the weather like in San Francisco?", + tools=tools + ) + print(f"Tool response: {response}") + + except Exception as e: + print(f"Error: {e}") + + +def demonstrate_api_feedback_with_precise_counting(): + """Show combination of precise counting + API feedback for ultimate accuracy.""" + + print("\\n=== API Feedback + Precise Counting ===") + + # Use API feedback rate limiter for real-time sync with OpenAI headers + api_feedback_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=8.0, + fallback_tokens_per_second=500.0, + estimated_tokens_per_request=400, # This will be overridden by precise counting + ) + + llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=api_feedback_limiter, + ) + + print("Using API feedback limiter with precise token counting") + print("This provides the most accurate rate limiting possible!") + + # Test with a few requests + test_prompts = [ + "Explain photosynthesis briefly.", + "What are the main principles of quantum mechanics?", + "Describe the process of machine learning model training.", + ] + + for i, prompt in enumerate(test_prompts, 1): + print(f"\\nRequest {i}: {prompt[:50]}...") + + # Show precise estimate + if hasattr(llm, '_estimate_input_tokens'): + estimated = llm._estimate_input_tokens(prompt) + print(f" Precise estimate: {estimated} tokens") + + # Show rate limiter status + status = api_feedback_limiter.get_status() + print(f" API feedback available: {status.get('has_fresh_feedback', False)}") + + try: + response = llm.invoke(prompt) + print(f" Response: {response.content[:80]}...") + + # Show updated status after API response + updated_status = api_feedback_limiter.get_status() + print(f" Updated feedback: {updated_status.get('has_fresh_feedback', False)}") + + except Exception as e: + print(f" Error: {e}") + + +async def async_precise_token_counting(): + """Demonstrate async precise token counting.""" + + print("\\n=== Async Precise Token Counting ===") + + token_limiter = TokenBucketRateLimiter( + tokens_per_second=600.0, + max_tokens=2000, + ) + + llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=token_limiter, + ) + + # Create multiple concurrent requests with different token requirements + tasks = [] + prompts = [ + "Short question?", # Low tokens + "Medium length question about artificial intelligence and its applications?", # Medium tokens + "Very long and detailed question about the implications of machine learning, artificial intelligence, deep learning, neural networks, and their impact on society, economy, and future technological development?", # High tokens + ] + + print("Creating concurrent requests with different token requirements...") + + for i, prompt in enumerate(prompts): + if hasattr(llm, '_estimate_input_tokens'): + estimated = llm._estimate_input_tokens(prompt) + print(f"Task {i+1} estimated tokens: {estimated}") + + task = asyncio.create_task(llm.ainvoke(prompt)) + tasks.append(task) + + # Wait for all requests + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in 
enumerate(results): + if isinstance(result, Exception): + print(f"Task {i+1} failed: {result}") + else: + print(f"Task {i+1} succeeded: {result.content[:50]}...") + + print(f"Final rate limiter status: {token_limiter.get_status()}") + + +def main(): + """Run all precise token counting examples.""" + + # Check for OpenAI API key + if not os.getenv("OPENAI_API_KEY"): + print("Warning: OPENAI_API_KEY not set. Some examples may not work.") + print("Set it with: export OPENAI_API_KEY=your-api-key") + return + + # Run sync examples + demonstrate_precise_token_counting() + compare_estimation_vs_actual() + demonstrate_tool_calling_token_counting() + demonstrate_api_feedback_with_precise_counting() + + # Run async example + asyncio.run(async_precise_token_counting()) + + print("\\n=== All Precise Token Counting Examples Complete! ===") + print("\\nKey Benefits:") + print("✅ Exact token counting with tiktoken") + print("✅ No more rough estimates or guessing") + print("✅ Efficient rate limiting - no wasted quota") + print("✅ Works with tools, context, and system instructions") + print("✅ Combines with API feedback for ultimate accuracy") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/customize/llms/rate_limiting_example.py b/examples/customize/llms/rate_limiting_example.py new file mode 100644 index 000000000..968c8e2b2 --- /dev/null +++ b/examples/customize/llms/rate_limiting_example.py @@ -0,0 +1,894 @@ +""" +Example demonstrating how to use rate limiting with neo4j-graphrag LLM providers. + +This example shows different ways to configure rate limiting for LLM providers +to handle API rate limits and avoid hitting provider constraints. + +Rate limiting is automatically applied to all LLM methods (invoke, ainvoke, +invoke_with_tools, ainvoke_with_tools) when a rate_limiter is provided. +""" + +import asyncio +import os +from typing import List + +from neo4j_graphrag.llm import ( + OpenAILLM, + AnthropicLLM, + SlotBucketRateLimiter, + TokenBucketRateLimiter, + CompositeRateLimiter, + APIFeedbackRateLimiter, + RetryConfig, +) +from neo4j_graphrag.experimental.components.entity_relation_extractor import ( + LLMEntityRelationExtractor, +) +from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks + + +def basic_rate_limiting_example(): + """Basic example of using rate limiting with OpenAI. + + Rate limiting is automatically applied to all LLM methods when configured. + """ + + # Create a rate limiter that allows 2 requests per second + rate_limiter = SlotBucketRateLimiter( + requests_per_second=2.0, + max_bucket_size=5, # Allow bursts up to 5 requests + ) + + # Create retry configuration + retry_config = RetryConfig( + max_retries=3, + base_delay=1.0, + max_delay=30.0, + exponential_base=2.0, + jitter=True, + ) + + # Create LLM with rate limiting - it's automatically applied to all methods + llm = OpenAILLM( + model_name="gpt-3.5-turbo", + model_params={"temperature": 0.0}, + rate_limiter=rate_limiter, # Automatic rate limiting! + retry_config=retry_config, + ) + + # Use the LLM normally - rate limiting happens automatically + response = llm.invoke("Hello, how are you?") + print(f"Response: {response.content}") + + +def provider_specific_rate_limiting(): + """Example using different rate limits for different providers. + + All providers automatically get rate limiting when configured. 
+ """ + + # OpenAI with conservative rate limiting (3 requests/second) + openai_rate_limiter = SlotBucketRateLimiter( + requests_per_second=3.0, + max_bucket_size=10, + ) + openai_llm = OpenAILLM( + model_name="gpt-4", + model_params={"temperature": 0.0}, + rate_limiter=openai_rate_limiter, # Automatic for all methods + ) + + # Anthropic with conservative rate limiting (2 requests/second) + anthropic_rate_limiter = SlotBucketRateLimiter( + requests_per_second=2.0, + max_bucket_size=5, + ) + anthropic_llm = AnthropicLLM( + model_name="claude-3-haiku-20240307", + model_params={"temperature": 0.0}, + rate_limiter=anthropic_rate_limiter, # Automatic for all methods + ) + + # Very conservative rate limiter for heavy workloads + custom_openai_rate_limiter = SlotBucketRateLimiter( + requests_per_second=1.0, + max_bucket_size=3, + ) + + custom_openai_llm = OpenAILLM( + model_name="gpt-4", + rate_limiter=custom_openai_rate_limiter, # Automatic for all methods + ) + + return openai_llm, anthropic_llm, custom_openai_llm + + +def shared_rate_limiter_example(): + """Example of sharing rate limiter between multiple LLM instances. + + Rate limiting is applied automatically to all methods for all instances. + """ + + # Create ONE rate limiter for OpenAI (since they share the same API key/limits) + shared_openai_limiter = SlotBucketRateLimiter( + requests_per_second=2.0, + max_bucket_size=5, + ) + + # Multiple LLMs sharing the same rate limiter + # Rate limiting is automatic for all methods on both instances + summarizer = OpenAILLM( + model_name="gpt-3.5-turbo", + model_params={"temperature": 0.0}, + rate_limiter=shared_openai_limiter, # Shared, automatic rate limiting! + ) + + reasoner = OpenAILLM( + model_name="gpt-4", + model_params={"temperature": 0.1}, + rate_limiter=shared_openai_limiter, # Same instance, automatic! + ) + + return summarizer, reasoner + + +async def rate_limiting_with_entity_extraction(): + """Example of using rate limiting with EntityRelationExtractor. + + Rate limiting is automatically applied to all LLM methods. + """ + + # Create rate limiter for OpenAI + rate_limiter = SlotBucketRateLimiter( + requests_per_second=2.0, # Conservative rate for entity extraction + max_bucket_size=5, + ) + + # Create LLM with rate limiting - automatically applied to all methods + llm = OpenAILLM( + model_name="gpt-4", + model_params={ + "temperature": 0.0, + "response_format": {"type": "json_object"}, + }, + rate_limiter=rate_limiter, # Automatic rate limiting! + ) + + # Create entity extractor + extractor = LLMEntityRelationExtractor( + llm=llm, + max_concurrency=3, # Process 3 chunks concurrently + # The rate limiter will automatically ensure we don't exceed our limits + # even with concurrent processing - no manual intervention needed! 
+ ) + + # Prepare some sample text chunks + sample_texts = [ + "Alice works at TechCorp as a software engineer.", + "Bob is the CEO of StartupInc based in Silicon Valley.", + "Charlie and Diana are researchers at MIT studying AI.", + "Eve leads the marketing team at GlobalBrand.", + "Frank is a doctor at City Hospital.", + ] + + chunks = TextChunks( + chunks=[ + TextChunk(text=text, index=i) + for i, text in enumerate(sample_texts) + ] + ) + + # Run extraction - rate limiting is applied automatically + print("Running entity extraction with automatic rate limiting...") + graph = await extractor.run(chunks=chunks) + + print(f"Extracted {len(graph.nodes)} nodes and {len(graph.relationships)} relationships") + for node in graph.nodes[:3]: # Show first 3 nodes + print(f"Node: {node.label} - {node.properties}") + + +def custom_rate_limiting_configuration(): + """Example of custom rate limiting configurations for different scenarios. + + All configurations automatically apply to all LLM methods. + """ + + # High-frequency, light requests (e.g., simple questions) + high_frequency_limiter = SlotBucketRateLimiter( + requests_per_second=5.0, + max_bucket_size=10, + ) + + # Low-frequency, heavy requests (e.g., document processing) + low_frequency_limiter = SlotBucketRateLimiter( + requests_per_second=0.5, # 1 request every 2 seconds + max_bucket_size=2, + ) + + # Burst-friendly configuration + burst_friendly_limiter = SlotBucketRateLimiter( + requests_per_second=2.0, + max_bucket_size=20, # Allow large bursts + ) + + # Aggressive retry configuration for critical operations + aggressive_retry_config = RetryConfig( + max_retries=5, + base_delay=2.0, + max_delay=120.0, # Up to 2 minutes + exponential_base=2.0, + jitter=True, + ) + + # Conservative retry configuration for non-critical operations + conservative_retry_config = RetryConfig( + max_retries=2, + base_delay=0.5, + max_delay=10.0, + exponential_base=1.5, + jitter=False, + ) + + # Create LLMs for different use cases + # Rate limiting is automatically applied to all methods + light_requests_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=high_frequency_limiter, # Automatic! + retry_config=conservative_retry_config, + ) + + heavy_requests_llm = OpenAILLM( + model_name="gpt-4", + rate_limiter=low_frequency_limiter, # Automatic! + retry_config=aggressive_retry_config, + ) + + burst_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=burst_friendly_limiter, # Automatic! + ) + + return light_requests_llm, heavy_requests_llm, burst_llm + + +def token_based_rate_limiting_example(): + """Example using token-based rate limiting for more accurate control. + + Token-based rate limiting tracks actual token consumption from API responses, + providing much more accurate rate limiting than simple request counting. + """ + + # Create a token-based rate limiter + # OpenAI GPT-4: typically 40,000 TPM (tokens per minute) = ~667 tokens/second + token_rate_limiter = TokenBucketRateLimiter( + tokens_per_second=600.0, # Conservative: 600 tokens/second + max_bucket_size=36000, # 1 minute worth of tokens + estimated_tokens_per_request=500, # Conservative estimate for pre-flight + ) + + # Create LLM with token-based rate limiting + llm = OpenAILLM( + model_name="gpt-4", + model_params={"temperature": 0.0}, + rate_limiter=token_rate_limiter, # Automatic token tracking! 
+ ) + + print("=== Token-Based Rate Limiting ===") + print("Using actual token consumption from API responses!") + + # Test with different sized requests + test_prompts = [ + "Hi", # Small: ~50 tokens + "Write a short story about a robot learning to paint.", # Medium: ~200-400 tokens + "Explain quantum computing, machine learning, and blockchain technology in detail with examples.", # Large: ~800-1200 tokens + ] + + for i, prompt in enumerate(test_prompts, 1): + print(f"\nRequest {i}: {prompt[:50]}...") + response = llm.invoke(prompt) + print(f"Response length: {len(response.content)} characters") + # The rate limiter automatically tracks actual token usage from the API response! + + +def comparing_slot_vs_token_rate_limiting(): + """Example comparing slot-based vs token-based rate limiting.""" + + print("=== Comparing Slot vs Token Rate Limiting ===") + + # Slot-based: treats all requests equally + slot_rate_limiter = SlotBucketRateLimiter( + requests_per_second=2.0, # 2 requests per second + max_bucket_size=5, + ) + + slot_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=slot_rate_limiter, + ) + + # Token-based: tracks actual token consumption + token_rate_limiter = TokenBucketRateLimiter( + tokens_per_second=1000.0, # 1000 tokens per second + max_bucket_size=5000, # 5 second buffer + estimated_tokens_per_request=300, + ) + + token_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=token_rate_limiter, + ) + + print("\nSlot-based limiter: Treats 'Hi' and long essays the same") + print("Token-based limiter: Accurately tracks actual token usage") + print("\nThis means:") + print("- You can make more small requests with token-based limiting") + print("- Large requests are properly throttled based on actual consumption") + print("- Better utilization of your API quotas!") + + +def advanced_token_rate_limiting(): + """Example of advanced token-based rate limiting configurations.""" + + # High-throughput configuration for GPT-3.5-turbo + high_throughput_limiter = TokenBucketRateLimiter( + tokens_per_second=3000.0, # 3000 tokens/second (180k/minute) + max_bucket_size=15000, # 5 seconds worth + estimated_tokens_per_request=200, # Optimistic estimate + ) + + # Conservative configuration for GPT-4 + conservative_limiter = TokenBucketRateLimiter( + tokens_per_second=500.0, # 500 tokens/second (30k/minute) + max_bucket_size=2500, # 5 seconds worth + estimated_tokens_per_request=800, # Conservative estimate + ) + + # Burst-friendly configuration + burst_friendly_limiter = TokenBucketRateLimiter( + tokens_per_second=1000.0, # 1000 tokens/second + max_bucket_size=60000, # 1 minute of tokens for bursts + estimated_tokens_per_request=400, + ) + + high_throughput_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=high_throughput_limiter, + ) + + conservative_llm = OpenAILLM( + model_name="gpt-4", + rate_limiter=conservative_limiter, + ) + + burst_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=burst_friendly_limiter, + ) + + print("=== Advanced Token Rate Limiting ===") + print("✓ High-throughput: 3000 tokens/second for GPT-3.5-turbo") + print("✓ Conservative: 500 tokens/second for GPT-4") + print("✓ Burst-friendly: Large token buffer for burst requests") + print("✓ All configurations automatically track actual token usage!") + + return high_throughput_llm, conservative_llm, burst_llm + + +def realistic_openai_rate_limiting(): + """Example using composite rate limiting that matches real OpenAI API limits. 
+ + OpenAI has BOTH request-based AND token-based limits: + - Requests per minute (RPM): e.g., 500 requests/minute + - Tokens per minute (TPM): e.g., 30,000 tokens/minute + + Depending on usage, you could hit either limit first! + """ + + print("=== Realistic OpenAI Rate Limiting ===") + print("Enforcing BOTH request limits AND token limits simultaneously!") + + # Request-based limiter: 500 requests/minute = ~8.33 requests/second + request_limiter = SlotBucketRateLimiter( + requests_per_second=8.0, # Slightly conservative + max_bucket_size=20, # Allow small bursts + ) + + # Token-based limiter: 30,000 tokens/minute = 500 tokens/second + token_limiter = TokenBucketRateLimiter( + tokens_per_second=500.0, + max_bucket_size=2500, # 5 seconds worth + estimated_tokens_per_request=400, + ) + + # Composite limiter: request must pass BOTH limits + composite_limiter = CompositeRateLimiter( + request_limiter, # Must have request slots available + token_limiter, # Must have token budget available + ) + + # Create LLM with realistic rate limiting + llm = OpenAILLM( + model_name="gpt-4", + model_params={"temperature": 0.0}, + rate_limiter=composite_limiter, # Enforces BOTH limits! + ) + + print("\nThis setup will:") + print("✓ Limit to ~8 requests/second (RPM limit)") + print("✓ Limit to ~500 tokens/second (TPM limit)") + print("✓ Block requests if EITHER limit is exceeded") + print("✓ Allow more small requests when token budget permits") + print("✓ Throttle large requests even if request budget permits") + + # Test scenarios + print("\n=== Test Scenarios ===") + + # Many small requests → will hit RPM limit first + print("Scenario 1: Many small requests ('Hi' repeated)") + print("→ Expected: Will hit REQUEST limit before token limit") + + # Few large requests → will hit TPM limit first + print("\nScenario 2: Few large requests (complex analysis)") + print("→ Expected: Will hit TOKEN limit before request limit") + + return llm + + +def provider_specific_composite_limiting(): + """Examples of composite rate limiting for different providers.""" + + print("=== Provider-Specific Composite Rate Limiting ===") + + # OpenAI GPT-4: 500 RPM, 30k TPM + openai_request_limiter = SlotBucketRateLimiter(requests_per_second=8.0) + openai_token_limiter = TokenBucketRateLimiter( + tokens_per_second=500.0, + estimated_tokens_per_request=600 + ) + openai_composite = CompositeRateLimiter(openai_request_limiter, openai_token_limiter) + + openai_llm = OpenAILLM( + model_name="gpt-4", + rate_limiter=openai_composite, + ) + + # OpenAI GPT-3.5-turbo: 3500 RPM, 90k TPM (higher limits) + openai_turbo_request_limiter = SlotBucketRateLimiter(requests_per_second=50.0) + openai_turbo_token_limiter = TokenBucketRateLimiter( + tokens_per_second=1500.0, + estimated_tokens_per_request=300 + ) + openai_turbo_composite = CompositeRateLimiter( + openai_turbo_request_limiter, + openai_turbo_token_limiter + ) + + openai_turbo_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=openai_turbo_composite, + ) + + # Conservative setup for production workloads + conservative_request_limiter = SlotBucketRateLimiter(requests_per_second=2.0) + conservative_token_limiter = TokenBucketRateLimiter( + tokens_per_second=200.0, + estimated_tokens_per_request=800 + ) + conservative_composite = CompositeRateLimiter( + conservative_request_limiter, + conservative_token_limiter + ) + + conservative_llm = OpenAILLM( + model_name="gpt-4", + rate_limiter=conservative_composite, + ) + + print("✅ GPT-4: 8 req/s + 500 tokens/s") + print("✅ GPT-3.5-turbo: 50 
req/s + 1500 tokens/s") + print("✅ Conservative: 2 req/s + 200 tokens/s") + print("✅ All enforce BOTH request AND token limits!") + + return openai_llm, openai_turbo_llm, conservative_llm + + +def api_feedback_rate_limiting_example(): + """Example using API feedback rate limiting that syncs with OpenAI's real-time headers. + + This is the most accurate rate limiting approach because it uses actual API feedback + instead of guessing or estimating limits. + """ + + print("=== API Feedback Rate Limiting ===") + print("Syncing with real-time OpenAI API headers!") + + # Create API feedback rate limiter + # It will sync with actual OpenAI rate limit headers for perfect accuracy + api_feedback_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=8.0, # Used when no API feedback available + fallback_tokens_per_second=500.0, # Used when no API feedback available + estimated_tokens_per_request=400, # For pre-flight estimation + ) + + # Create LLM with API feedback rate limiting + llm = OpenAILLM( + model_name="gpt-4", + model_params={"temperature": 0.0}, + rate_limiter=api_feedback_limiter, # Real-time API sync! + ) + + print("\nBenefits of API Feedback Rate Limiting:") + print("✓ Syncs with actual API limits in real-time") + print("✓ Knows exactly how many requests/tokens are remaining") + print("✓ Knows exactly when limits reset") + print("✓ No guesswork or estimation errors") + print("✓ Maximum API quota utilization") + + # Make a test request + print("\nMaking test request...") + response = llm.invoke("What is the capital of France?") + print(f"Response: {response.content[:100]}...") + + # Check rate limiter status (debugging info) + status = api_feedback_limiter.get_status() + print(f"\nRate Limiter Status:") + print(f"Remaining requests: {status['remaining_requests']}") + print(f"Remaining tokens: {status['remaining_tokens']}") + print(f"Request reset in: {status['request_reset_in_seconds']:.1f}s" if status['request_reset_in_seconds'] else "Unknown") + print(f"Token reset in: {status['token_reset_in_seconds']:.1f}s" if status['token_reset_in_seconds'] else "Unknown") + print(f"Request limit: {status['limit_requests']}") + print(f"Token limit: {status['limit_tokens']}") + print(f"Has fresh feedback: {status['has_fresh_feedback']}") + + return llm + + +def comparing_rate_limiting_approaches(): + """Compare different rate limiting approaches to show the benefits of API feedback.""" + + print("=== Comparing Rate Limiting Approaches ===") + + # 1. Basic slot-based (treats all requests equally) + slot_limiter = SlotBucketRateLimiter(requests_per_second=2.0) + slot_llm = OpenAILLM(model_name="gpt-3.5-turbo", rate_limiter=slot_limiter) + + # 2. Token-based (tracks actual usage) + token_limiter = TokenBucketRateLimiter( + tokens_per_second=1000.0, + estimated_tokens_per_request=400 + ) + token_llm = OpenAILLM(model_name="gpt-3.5-turbo", rate_limiter=token_limiter) + + # 3. 
API feedback (syncs with real API limits) + api_feedback_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=8.0, + fallback_tokens_per_second=1000.0, + estimated_tokens_per_request=400 + ) + api_feedback_llm = OpenAILLM(model_name="gpt-3.5-turbo", rate_limiter=api_feedback_limiter) + + print("\n🔹 Slot-based:") + print(" → All requests treated equally") + print(" → Simple but inefficient") + print(" → May waste or exceed quotas") + + print("\n🔸 Token-based:") + print(" → Tracks actual token usage") + print(" → Better quota utilization") + print(" → Still based on estimates") + + print("\n🔶 API Feedback:") + print(" → Syncs with real API limits") + print(" → Perfect accuracy") + print(" → Maximum quota utilization") + print(" → Future-proof (adapts to API changes)") + + return slot_llm, token_llm, api_feedback_llm + + +def advanced_api_feedback_scenarios(): + """Advanced scenarios showing API feedback rate limiting capabilities.""" + + print("=== Advanced API Feedback Scenarios ===") + + # Scenario 1: High-throughput with API feedback + high_throughput_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=50.0, # High fallback rate + fallback_tokens_per_second=3000.0, # High fallback token rate + estimated_tokens_per_request=200, # Optimistic estimate + ) + + high_throughput_llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=high_throughput_limiter, + ) + + # Scenario 2: Conservative with API feedback + conservative_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=2.0, # Conservative fallback + fallback_tokens_per_second=300.0, # Conservative fallback + estimated_tokens_per_request=800, # Conservative estimate + ) + + conservative_llm = OpenAILLM( + model_name="gpt-4", + rate_limiter=conservative_limiter, + ) + + # Scenario 3: Multiple models sharing API feedback + # Since they share the same API key, they should share rate limits + shared_api_feedback_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=8.0, + fallback_tokens_per_second=500.0, + estimated_tokens_per_request=500, + ) + + summarizer = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=shared_api_feedback_limiter, + ) + + analyzer = OpenAILLM( + model_name="gpt-4", + rate_limiter=shared_api_feedback_limiter, # Same instance! + ) + + print("✅ High-throughput: Optimistic fallbacks, syncs with real limits") + print("✅ Conservative: Safe fallbacks, syncs with real limits") + print("✅ Shared limits: Multiple models coordinate via API feedback") + + return high_throughput_llm, conservative_llm, summarizer, analyzer + + +def api_feedback_debugging_example(): + """Example showing how to debug and monitor API feedback rate limiting.""" + + print("=== API Feedback Debugging ===") + + # Create rate limiter with debugging capabilities + debug_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=5.0, + fallback_tokens_per_second=1000.0, + estimated_tokens_per_request=300, + ) + + llm = OpenAILLM( + model_name="gpt-3.5-turbo", + rate_limiter=debug_limiter, + ) + + print("Debugging capabilities:") + print("✓ View remaining requests/tokens in real-time") + print("✓ See when limits reset") + print("✓ Monitor API feedback freshness") + print("✓ Track estimation accuracy") + + # Simulate some requests and show debugging info + test_prompts = [ + "Hi there!", + "Explain quantum computing briefly.", + "Write a short poem about spring.", + ] + + for i, prompt in enumerate(test_prompts, 1): + print(f"\n--- Request {i}: {prompt[:30]}... 
---") + + # Show status before request + status_before = debug_limiter.get_status() + print(f"Before: {status_before['remaining_requests'] or 'Unknown'} requests, " + f"{status_before['remaining_tokens'] or 'Unknown'} tokens remaining") + + # Make request (rate limiter gets updated automatically) + try: + response = llm.invoke(prompt) + print(f"✓ Success: {len(response.content)} chars") + except Exception as e: + print(f"✗ Error: {e}") + + # Show status after request + status_after = debug_limiter.get_status() + print(f"After: {status_after['remaining_requests'] or 'Unknown'} requests, " + f"{status_after['remaining_tokens'] or 'Unknown'} tokens remaining") + + if status_after['has_fresh_feedback']: + print("📡 Fresh API feedback available") + else: + print("⏳ Using fallback rate limiting") + + # Final status + final_status = debug_limiter.get_status() + print(f"\n=== Final Status ===") + for key, value in final_status.items(): + if key.endswith('_seconds') and value is not None: + print(f"{key}: {value:.1f}s") + else: + print(f"{key}: {value}") + + return llm, debug_limiter + + +async def api_feedback_with_entity_extraction(): + """Example of API feedback rate limiting with entity extraction.""" + + print("=== API Feedback + Entity Extraction ===") + + # Create API feedback rate limiter + api_feedback_limiter = APIFeedbackRateLimiter( + fallback_requests_per_second=3.0, # Conservative for entity extraction + fallback_tokens_per_second=600.0, + estimated_tokens_per_request=600, # Entity extraction uses more tokens + ) + + # Create LLM with API feedback rate limiting + llm = OpenAILLM( + model_name="gpt-4", + model_params={ + "temperature": 0.0, + "response_format": {"type": "json_object"}, + }, + rate_limiter=api_feedback_limiter, + ) + + # Create entity extractor + extractor = LLMEntityRelationExtractor( + llm=llm, + max_concurrency=2, # Conservative concurrency + ) + + # Sample texts + sample_texts = [ + "Dr. Alice Johnson works at TechCorp's research division in Boston.", + "Bob Chen, the CEO of StartupInc, announced a partnership with GlobalTech.", + "Professor Diana Martinez from MIT published research on quantum computing." 
+ ] + + chunks = TextChunks( + chunks=[ + TextChunk(text=text, index=i) + for i, text in enumerate(sample_texts) + ] + ) + + print("Running entity extraction with API feedback rate limiting...") + + # Check status before + status_before = api_feedback_limiter.get_status() + print(f"Before extraction: {status_before['remaining_requests'] or 'Unknown'} requests available") + + # Run extraction (rate limiting happens automatically) + graph = await extractor.run(chunks=chunks) + + # Check status after + status_after = api_feedback_limiter.get_status() + print(f"After extraction: {status_after['remaining_requests'] or 'Unknown'} requests available") + + print(f"Extracted {len(graph.nodes)} nodes and {len(graph.relationships)} relationships") + print("✓ Rate limiting was applied automatically and accurately!") + + return graph, api_feedback_limiter + + +async def main(): + """Main function demonstrating different rate limiting approaches.""" + + print("=== Basic Rate Limiting Example ===") + print("Rate limiting is automatically applied to all LLM methods!") + try: + basic_rate_limiting_example() + except Exception as e: + print(f"Error in basic example: {e}") + + print("\n=== Provider-Specific Rate Limiting ===") + print("All providers get automatic rate limiting when configured!") + try: + openai_llm, anthropic_llm, custom_llm = provider_specific_rate_limiting() + print("Successfully created rate-limited LLMs for different providers") + except Exception as e: + print(f"Error in provider-specific example: {e}") + + print("\n=== Shared Rate Limiter Example ===") + print("Multiple LLMs can share the same rate limiter automatically!") + try: + summarizer, reasoner = shared_rate_limiter_example() + print("Successfully created LLMs with shared rate limiter") + except Exception as e: + print(f"Error in shared rate limiter example: {e}") + + print("\n=== Rate Limiting with Entity Extraction ===") + print("Rate limiting works automatically with all neo4j-graphrag components!") + try: + await rate_limiting_with_entity_extraction() + except Exception as e: + print(f"Error in entity extraction example: {e}") + + print("\n=== Custom Rate Limiting Configurations ===") + print("Different rate limiting strategies are automatically applied!") + try: + light_llm, heavy_llm, burst_llm = custom_rate_limiting_configuration() + print("Successfully created custom rate-limited LLM configurations") + except Exception as e: + print(f"Error in custom configuration example: {e}") + + print("\n=== Token-Based Rate Limiting ===") + print("Using actual token consumption from API responses!") + try: + token_based_rate_limiting_example() + except Exception as e: + print(f"Error in token-based rate limiting example: {e}") + + print("\n=== Comparing Slot vs Token Rate Limiting ===") + try: + comparing_slot_vs_token_rate_limiting() + except Exception as e: + print(f"Error in comparing slot vs token rate limiting example: {e}") + + print("\n=== Advanced Token Rate Limiting ===") + try: + high_throughput_llm, conservative_llm, burst_llm = advanced_token_rate_limiting() + print("Successfully created advanced token-based rate-limited LLM configurations") + except Exception as e: + print(f"Error in advanced token rate limiting example: {e}") + + print("\n=== Realistic OpenAI Rate Limiting ===") + try: + llm = realistic_openai_rate_limiting() + print("Successfully created realistic rate-limited LLM") + except Exception as e: + print(f"Error in realistic OpenAI rate limiting example: {e}") + + print("\n=== Provider-Specific Composite 
Rate Limiting ===") + try: + openai_llm, openai_turbo_llm, conservative_llm = provider_specific_composite_limiting() + print("Successfully created composite rate-limited LLM configurations") + except Exception as e: + print(f"Error in provider-specific composite rate limiting example: {e}") + + try: + llm = api_feedback_rate_limiting_example() + print("Successfully created API feedback rate-limited LLM") + except Exception as e: + print(f"Error in API feedback rate limiting example: {e}") + + print("\n=== Comparing Rate Limiting Approaches ===") + try: + slot_llm, token_llm, api_feedback_llm = comparing_rate_limiting_approaches() + print("Successfully created rate-limited LLM configurations") + except Exception as e: + print(f"Error in comparing rate limiting approaches example: {e}") + + print("\n=== Advanced API Feedback Scenarios ===") + try: + high_throughput_llm, conservative_llm, summarizer, analyzer = advanced_api_feedback_scenarios() + print("Successfully created advanced API feedback rate-limited LLM configurations") + except Exception as e: + print(f"Error in advanced API feedback scenarios example: {e}") + + print("\n=== API Feedback Debugging ===") + try: + llm, debug_limiter = api_feedback_debugging_example() + print("Successfully created API feedback rate-limited LLM") + except Exception as e: + print(f"Error in API feedback debugging example: {e}") + + try: + graph, api_feedback_limiter = api_feedback_with_entity_extraction() + print("Successfully created API feedback rate-limited LLM") + except Exception as e: + print(f"Error in API feedback with entity extraction example: {e}") + + print("\n=== Summary ===") + print("✓ Rate limiting is now AUTOMATIC for all LLM providers!") + print("✓ No need to manually apply rate limiting in implementations") + print("✓ Works with invoke, ainvoke, invoke_with_tools, ainvoke_with_tools") + print("✓ All existing and future LLM providers get rate limiting for free") + + +if __name__ == "__main__": + # Note: Make sure to set your API keys: + # export OPENAI_API_KEY="your-openai-api-key" + # export ANTHROPIC_API_KEY="your-anthropic-api-key" + + if not os.getenv("OPENAI_API_KEY"): + print("Warning: OPENAI_API_KEY not set. Some examples may fail.") + + asyncio.run(main()) \ No newline at end of file diff --git a/src/neo4j_graphrag/llm/__init__.py b/src/neo4j_graphrag/llm/__init__.py index a9ece5ccb..2f5b53400 100644 --- a/src/neo4j_graphrag/llm/__init__.py +++ b/src/neo4j_graphrag/llm/__init__.py @@ -20,15 +20,32 @@ from .openai_llm import AzureOpenAILLM, OpenAILLM from .types import LLMResponse from .vertexai_llm import VertexAILLM +from .rate_limiter import ( + BaseRateLimiter, + SlotBucketRateLimiter, + TokenBucketRateLimiter, + TokenTrackingRateLimiter, + CompositeRateLimiter, + APIFeedbackRateLimiter, + RetryConfig, +) __all__ = [ "AnthropicLLM", + "AzureOpenAILLM", "CohereLLM", - "LLMResponse", "LLMInterface", + "LLMResponse", + "MistralAILLM", "OllamaLLM", "OpenAILLM", "VertexAILLM", - "AzureOpenAILLM", - "MistralAILLM", + # Rate limiting exports + "BaseRateLimiter", + "SlotBucketRateLimiter", + "TokenBucketRateLimiter", + "TokenTrackingRateLimiter", + "CompositeRateLimiter", + "APIFeedbackRateLimiter", + "RetryConfig", ] diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 87d281794..c30497ded 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -14,16 +14,27 @@ # limitations under the License. 
from __future__ import annotations +import asyncio +import logging +import time from abc import ABC, abstractmethod -from typing import Any, List, Optional, Sequence, Union +from functools import wraps +from typing import Any, Callable, List, Optional, Sequence, TypeVar, Union from neo4j_graphrag.message_history import MessageHistory from neo4j_graphrag.types import LLMMessage from .types import LLMResponse, ToolCallResponse +from .rate_limiter import BaseRateLimiter, RetryConfig, is_rate_limit_error +from ..exceptions import LLMGenerationError from neo4j_graphrag.tool import Tool +logger = logging.getLogger(__name__) + +# Type variable for function return types +T = TypeVar("T") + class LLMInterface(ABC): """Interface for large language models. @@ -31,6 +42,8 @@ class LLMInterface(ABC): Args: model_name (str): The name of the language model. model_params (Optional[dict]): Additional parameters passed to the model when text is sent to it. Defaults to None. + rate_limiter (Optional[BaseRateLimiter]): Rate limiter to control request frequency. Defaults to None. + retry_config (Optional[RetryConfig]): Configuration for retry behavior on rate limit errors. Defaults to None. **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. """ @@ -38,12 +51,242 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limiter: Optional[BaseRateLimiter] = None, + retry_config: Optional[RetryConfig] = None, **kwargs: Any, ): self.model_name = model_name self.model_params = model_params or {} + + # Rate limiting setup + self.rate_limiter = rate_limiter + self.retry_config = retry_config or RetryConfig() + self._llm_name = f"{self.__class__.__name__}({model_name})" + + def _apply_rate_limiting(self, func: Callable[..., T]) -> Callable[..., T]: + """Apply rate limiting to a function if rate limiter is configured.""" + if not self.rate_limiter: + return func + + @wraps(func) + def wrapper(*args, **kwargs) -> T: + last_error = None + + for attempt in range(self.retry_config.max_retries + 1): + # Get precise token estimate if implementation supports it + estimated_tokens = None + if hasattr(self, '_estimate_input_tokens'): + try: + # Extract parameters based on function signature + if len(args) >= 1: # input is first positional arg + input_text = args[0] + message_history = args[1] if len(args) >= 2 else kwargs.get('message_history') + system_instruction = args[2] if len(args) >= 3 else kwargs.get('system_instruction') + tools = args[3] if len(args) >= 4 else kwargs.get('tools') + + estimated_tokens = self._estimate_input_tokens( + input_text, message_history, system_instruction, tools + ) + except Exception: + # If token estimation fails, continue without it + estimated_tokens = None + + # Wait for rate limiter with precise token count if available + if estimated_tokens is not None: + while not self.rate_limiter.acquire(estimated_tokens): + time.sleep(0.1) + else: + while not self.rate_limiter.acquire(): + time.sleep(0.1) + + try: + return func(*args, **kwargs) + except Exception as e: + last_error = e + + if is_rate_limit_error(e) and attempt < self.retry_config.max_retries: + delay = self.retry_config.get_delay(attempt) + logger.warning( + f"{self._llm_name}: Rate limit hit, retrying in {delay:.2f}s " + f"(attempt {attempt + 1}/{self.retry_config.max_retries + 1})" + ) + time.sleep(delay) + continue + else: + # Re-raise if not a rate limit error or max retries reached + raise + + # This shouldn't be reached, but just in case + 
raise last_error or LLMGenerationError("Max retries exceeded") + + return wrapper + + def _apply_async_rate_limiting(self, func: Callable[..., T]) -> Callable[..., T]: + """Apply rate limiting to an async function if rate limiter is configured.""" + if not self.rate_limiter: + return func + + @wraps(func) + async def wrapper(*args, **kwargs) -> T: + last_error = None + + for attempt in range(self.retry_config.max_retries + 1): + # Get precise token estimate if implementation supports it + estimated_tokens = None + if hasattr(self, '_estimate_input_tokens'): + try: + # Extract parameters based on function signature + if len(args) >= 1: # input is first positional arg + input_text = args[0] + message_history = args[1] if len(args) >= 2 else kwargs.get('message_history') + system_instruction = args[2] if len(args) >= 3 else kwargs.get('system_instruction') + tools = args[3] if len(args) >= 4 else kwargs.get('tools') + + estimated_tokens = self._estimate_input_tokens( + input_text, message_history, system_instruction, tools + ) + except Exception: + # If token estimation fails, continue without it + estimated_tokens = None + + # Wait for rate limiter with precise token count if available + if estimated_tokens is not None: + while not await self.rate_limiter.aacquire(estimated_tokens): + await asyncio.sleep(0.1) + else: + while not await self.rate_limiter.aacquire(): + await asyncio.sleep(0.1) + + try: + return await func(*args, **kwargs) + except Exception as e: + last_error = e + + if is_rate_limit_error(e) and attempt < self.retry_config.max_retries: + delay = self.retry_config.get_delay(attempt) + logger.warning( + f"{self._llm_name}: Rate limit hit, retrying in {delay:.2f}s " + f"(attempt {attempt + 1}/{self.retry_config.max_retries + 1})" + ) + await asyncio.sleep(delay) + continue + else: + # Re-raise if not a rate limit error or max retries reached + raise + + # This shouldn't be reached, but just in case + raise last_error or LLMGenerationError("Max retries exceeded") + + return wrapper + + # Abstract methods that implementations must override + @abstractmethod + def _invoke( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: + """Internal method to send a text input to the LLM and retrieve a response. + + This method should be overridden by concrete implementations. + Rate limiting will be applied automatically by the public invoke method. + + Args: + input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. + + Returns: + LLMResponse: The response from the LLM. + + Raises: + LLMGenerationError: If anything goes wrong. + """ @abstractmethod + async def _ainvoke( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: + """Internal async method to send a text input to the LLM and retrieve a response. + + This method should be overridden by concrete implementations. + Rate limiting will be applied automatically by the public ainvoke method. + + Args: + input (str): Text sent to the LLM. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. 
+ system_instruction (Optional[str]): An option to override the llm system message for this invocation. + + Returns: + LLMResponse: The response from the LLM. + + Raises: + LLMGenerationError: If anything goes wrong. + """ + + def _invoke_with_tools( + self, + input: str, + tools: Sequence[Tool], + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> ToolCallResponse: + """Internal method to send a text input to the LLM with tool definitions. + + This method can be overridden by concrete implementations that support tool calling. + Rate limiting will be applied automatically by the public invoke_with_tools method. + + Args: + input (str): Text sent to the LLM. + tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. + + Returns: + ToolCallResponse: The response from the LLM containing a tool call. + + Raises: + LLMGenerationError: If anything goes wrong. + NotImplementedError: If the LLM provider does not support tool calling. + """ + raise NotImplementedError("This LLM provider does not support tool calling.") + + async def _ainvoke_with_tools( + self, + input: str, + tools: Sequence[Tool], + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> ToolCallResponse: + """Internal async method to send a text input to the LLM with tool definitions. + + This method can be overridden by concrete implementations that support tool calling. + Rate limiting will be applied automatically by the public ainvoke_with_tools method. + + Args: + input (str): Text sent to the LLM. + tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format. + message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, + with each message having a specific role assigned. + system_instruction (Optional[str]): An option to override the llm system message for this invocation. + + Returns: + ToolCallResponse: The response from the LLM containing a tool call. + + Raises: + LLMGenerationError: If anything goes wrong. + NotImplementedError: If the LLM provider does not support tool calling. + """ + raise NotImplementedError("This LLM provider does not support tool calling.") + + # Public methods that automatically apply rate limiting def invoke( self, input: str, @@ -51,6 +294,8 @@ def invoke( system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. + + Rate limiting is applied automatically if configured. Args: input (str): Text sent to the LLM. @@ -64,8 +309,9 @@ def invoke( Raises: LLMGenerationError: If anything goes wrong. """ + rate_limited_invoke = self._apply_rate_limiting(self._invoke) + return rate_limited_invoke(input, message_history, system_instruction) - @abstractmethod async def ainvoke( self, input: str, @@ -73,6 +319,8 @@ async def ainvoke( system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. + + Rate limiting is applied automatically if configured. 
Args: input (str): Text sent to the LLM. @@ -86,6 +334,8 @@ async def ainvoke( Raises: LLMGenerationError: If anything goes wrong. """ + rate_limited_ainvoke = self._apply_async_rate_limiting(self._ainvoke) + return await rate_limited_ainvoke(input, message_history, system_instruction) def invoke_with_tools( self, @@ -95,8 +345,8 @@ def invoke_with_tools( system_instruction: Optional[str] = None, ) -> ToolCallResponse: """Sends a text input to the LLM with tool definitions and retrieves a tool call response. - - This is a default implementation that should be overridden by LLM providers that support tool/function calling. + + Rate limiting is applied automatically if configured. Args: input (str): Text sent to the LLM. @@ -112,7 +362,8 @@ def invoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ - raise NotImplementedError("This LLM provider does not support tool calling.") + rate_limited_invoke_with_tools = self._apply_rate_limiting(self._invoke_with_tools) + return rate_limited_invoke_with_tools(input, tools, message_history, system_instruction) async def ainvoke_with_tools( self, @@ -122,8 +373,8 @@ async def ainvoke_with_tools( system_instruction: Optional[str] = None, ) -> ToolCallResponse: """Asynchronously sends a text input to the LLM with tool definitions and retrieves a tool call response. - - This is a default implementation that should be overridden by LLM providers that support tool/function calling. + + Rate limiting is applied automatically if configured. Args: input (str): Text sent to the LLM. @@ -139,4 +390,5 @@ async def ainvoke_with_tools( LLMGenerationError: If anything goes wrong. NotImplementedError: If the LLM provider does not support tool calling. """ - raise NotImplementedError("This LLM provider does not support tool calling.") + rate_limited_ainvoke_with_tools = self._apply_async_rate_limiting(self._ainvoke_with_tools) + return await rate_limited_ainvoke_with_tools(input, tools, message_history, system_instruction) diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 1e0228e45..96adff413 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -50,10 +50,18 @@ ) from neo4j_graphrag.tool import Tool +from .rate_limiter import BaseRateLimiter, RetryConfig if TYPE_CHECKING: import openai +# Try to import tiktoken for precise token counting +try: + import tiktoken + TIKTOKEN_AVAILABLE = True +except ImportError: + TIKTOKEN_AVAILABLE = False + class BaseOpenAILLM(LLMInterface, abc.ABC): client: openai.OpenAI @@ -63,6 +71,8 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + rate_limiter: Optional[BaseRateLimiter] = None, + retry_config: Optional[RetryConfig] = None, ): """ Base class for OpenAI LLM. @@ -72,6 +82,8 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. + rate_limiter (Optional[BaseRateLimiter]): Rate limiter to control request frequency. Defaults to None. + retry_config (Optional[RetryConfig]): Configuration for retry behavior on rate limit errors. Defaults to None. 
""" try: import openai @@ -81,7 +93,16 @@ def __init__( Please install it with `pip install "neo4j-graphrag[openai]"`.""" ) self.openai = openai - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, rate_limiter, retry_config) + + # Initialize tokenizer for precise token counting + self._tokenizer = None + if TIKTOKEN_AVAILABLE: + try: + self._tokenizer = tiktoken.encoding_for_model(model_name) + except KeyError: + # Fallback to cl100k_base for unknown models (most OpenAI models use this) + self._tokenizer = tiktoken.get_encoding("cl100k_base") def get_messages( self, @@ -124,62 +145,102 @@ def _convert_tool_to_openai_format(self, tool: Tool) -> Dict[str, Any]: except AttributeError: raise LLMGenerationError(f"Tool {tool} is not a valid Tool object") - def invoke( + def _invoke_openai( self, input: str, message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: - """Sends a text input to the OpenAI chat completion model - and returns the response's content. - - Args: - input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. - - Returns: - LLMResponse: The response from OpenAI. + """Internal method to call OpenAI API.""" + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + + response = self.client.chat.completions.create( + messages=self.get_messages(input, message_history, system_instruction), + model=self.model_name, + **self.model_params, + ) + + # Update rate limiter with API feedback from headers + if hasattr(response, '_response') and hasattr(response._response, 'headers'): + headers = dict(response._response.headers) + self._update_from_api_response_if_supported(headers) + + # Update token usage if rate limiter supports it + if response.usage and response.usage.total_tokens: + self._update_token_usage_if_supported(response.usage.total_tokens) + + content = response.choices[0].message.content or "" + return LLMResponse(content=content) + + async def _ainvoke_openai( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: + """Internal async method to call OpenAI API.""" + if isinstance(message_history, MessageHistory): + message_history = message_history.messages + + response = await self.async_client.chat.completions.create( + messages=self.get_messages(input, message_history, system_instruction), + model=self.model_name, + **self.model_params, + ) + + # Update rate limiter with API feedback from headers + if hasattr(response, '_response') and hasattr(response._response, 'headers'): + headers = dict(response._response.headers) + await self._aupdate_from_api_response_if_supported(headers) + + # Update token usage if rate limiter supports it + if response.usage and response.usage.total_tokens: + await self._aupdate_token_usage_if_supported(response.usage.total_tokens) + + content = response.choices[0].message.content or "" + return LLMResponse(content=content) + + def _invoke( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: + """Internal method that implements the LLM invocation 
for OpenAI. + + Rate limiting is handled automatically by the base class. + """ + try: + return self._invoke_openai(input, message_history, system_instruction) + except self.openai.OpenAIError as e: + raise LLMGenerationError(e) - Raises: - LLMGenerationError: If anything goes wrong. + async def _ainvoke( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> LLMResponse: + """Internal async method that implements the LLM invocation for OpenAI. + + Rate limiting is handled automatically by the base class. """ try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - response = self.client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), - model=self.model_name, - **self.model_params, - ) - content = response.choices[0].message.content or "" - return LLMResponse(content=content) + return await self._ainvoke_openai(input, message_history, system_instruction) except self.openai.OpenAIError as e: raise LLMGenerationError(e) - def invoke_with_tools( + def _invoke_with_tools( self, input: str, tools: Sequence[Tool], # Tools definition as a sequence of Tool objects message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> ToolCallResponse: - """Sends a text input to the OpenAI chat completion model with tool definitions - and retrieves a tool call response. - - Args: - input (str): Text sent to the LLM. - tools (List[Tool]): List of Tools for the LLM to choose from. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. - - Returns: - ToolCallResponse: The response from the LLM containing a tool call. - - Raises: - LLMGenerationError: If anything goes wrong. + """Internal method that implements tool calling for OpenAI. + + Rate limiting is handled automatically by the base class. """ try: if isinstance(message_history, MessageHistory): @@ -202,6 +263,15 @@ def invoke_with_tools( tool_choice="auto", **params, ) + + # Update rate limiter with API feedback from headers + if hasattr(response, '_response') and hasattr(response._response, 'headers'): + headers = dict(response._response.headers) + self._update_from_api_response_if_supported(headers) + + # Update token usage if rate limiter supports it + if response.usage and response.usage.total_tokens: + self._update_token_usage_if_supported(response.usage.total_tokens) message = response.choices[0].message @@ -232,62 +302,16 @@ def invoke_with_tools( except self.openai.OpenAIError as e: raise LLMGenerationError(e) - async def ainvoke( - self, - input: str, - message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, - system_instruction: Optional[str] = None, - ) -> LLMResponse: - """Asynchronously sends a text input to the OpenAI chat - completion model and returns the response's content. - - Args: - input (str): Text sent to the LLM. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. - - Returns: - LLMResponse: The response from OpenAI. 
- - Raises: - LLMGenerationError: If anything goes wrong. - """ - try: - if isinstance(message_history, MessageHistory): - message_history = message_history.messages - response = await self.async_client.chat.completions.create( - messages=self.get_messages(input, message_history, system_instruction), - model=self.model_name, - **self.model_params, - ) - content = response.choices[0].message.content or "" - return LLMResponse(content=content) - except self.openai.OpenAIError as e: - raise LLMGenerationError(e) - - async def ainvoke_with_tools( + async def _ainvoke_with_tools( self, input: str, tools: Sequence[Tool], # Tools definition as a sequence of Tool objects message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, system_instruction: Optional[str] = None, ) -> ToolCallResponse: - """Asynchronously sends a text input to the OpenAI chat completion model with tool definitions - and retrieves a tool call response. - - Args: - input (str): Text sent to the LLM. - tools (List[Tool]): List of Tools for the LLM to choose from. - message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection previous messages, - with each message having a specific role assigned. - system_instruction (Optional[str]): An option to override the llm system message for this invocation. - - Returns: - ToolCallResponse: The response from the LLM containing a tool call. - - Raises: - LLMGenerationError: If anything goes wrong. + """Internal async method that implements tool calling for OpenAI. + + Rate limiting is handled automatically by the base class. """ try: if isinstance(message_history, MessageHistory): @@ -310,6 +334,15 @@ async def ainvoke_with_tools( tool_choice="auto", **params, ) + + # Update rate limiter with API feedback from headers + if hasattr(response, '_response') and hasattr(response._response, 'headers'): + headers = dict(response._response.headers) + await self._aupdate_from_api_response_if_supported(headers) + + # Update token usage if rate limiter supports it + if response.usage and response.usage.total_tokens: + await self._aupdate_token_usage_if_supported(response.usage.total_tokens) message = response.choices[0].message @@ -341,6 +374,68 @@ async def ainvoke_with_tools( except self.openai.OpenAIError as e: raise LLMGenerationError(e) + def _count_message_tokens(self, messages: Iterable[ChatCompletionMessageParam]) -> int: + """Count tokens in messages using tiktoken for precise counting.""" + if not self._tokenizer: + # Fallback estimation if tiktoken not available + total_chars = sum(len(str(msg.get('content', ''))) for msg in messages) + return max(1, total_chars // 4) # Rough estimate: 4 chars per token + + total_tokens = 0 + + # Count tokens for each message + for message in messages: + # OpenAI chat format overhead: each message has ~4 tokens of overhead + total_tokens += 4 + + # Count content tokens + content = message.get('content', '') + if content: + total_tokens += len(self._tokenizer.encode(str(content))) + + # Count role tokens + role = message.get('role', '') + if role: + total_tokens += len(self._tokenizer.encode(role)) + + # Add 2 tokens for the assistant's reply priming + total_tokens += 2 + + return total_tokens + + def _count_tools_tokens(self, tools: Optional[Sequence[Tool]]) -> int: + """Count tokens used by tool definitions.""" + if not tools or not self._tokenizer: + return 0 + + total_tokens = 0 + for tool in tools: + # Convert tool to OpenAI format and count tokens + tool_dict = self._convert_tool_to_openai_format(tool) + tool_json 
= json.dumps(tool_dict) + total_tokens += len(self._tokenizer.encode(tool_json)) + + return total_tokens + + def _estimate_input_tokens( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + tools: Optional[Sequence[Tool]] = None, + ) -> int: + """Estimate total input tokens for a request.""" + # Get messages that will be sent + messages = list(self.get_messages(input, message_history, system_instruction)) + + # Count message tokens + message_tokens = self._count_message_tokens(messages) + + # Count tool tokens + tool_tokens = self._count_tools_tokens(tools) + + return message_tokens + tool_tokens + class OpenAILLM(BaseOpenAILLM): def __init__( @@ -358,7 +453,11 @@ def __init__( model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + # Extract rate limiting parameters from kwargs if present + rate_limiter = kwargs.pop('rate_limiter', None) + retry_config = kwargs.pop('retry_config', None) + + super().__init__(model_name, model_params, rate_limiter, retry_config) self.client = self.openai.OpenAI(**kwargs) self.async_client = self.openai.AsyncOpenAI(**kwargs) @@ -379,6 +478,10 @@ def __init__( model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + # Extract rate limiting parameters from kwargs if present + rate_limiter = kwargs.pop('rate_limiter', None) + retry_config = kwargs.pop('retry_config', None) + + super().__init__(model_name, model_params, rate_limiter, retry_config) self.client = self.openai.AzureOpenAI(**kwargs) self.async_client = self.openai.AsyncAzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/llm/rate_limiter.py b/src/neo4j_graphrag/llm/rate_limiter.py new file mode 100644 index 000000000..29badbe75 --- /dev/null +++ b/src/neo4j_graphrag/llm/rate_limiter.py @@ -0,0 +1,592 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import random +import threading +import time +from abc import ABC, abstractmethod +from functools import wraps +from typing import Any, Callable, Optional, TypeVar, Dict + +from ..exceptions import LLMGenerationError + +logger = logging.getLogger(__name__) + +# Type variable for function return types +T = TypeVar("T") + + +class BaseRateLimiter(ABC): + """Base class for rate limiters.""" + + @abstractmethod + def acquire(self, slots: int = 1) -> bool: + """Attempt to acquire slots from the rate limiter. 
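The BaseRateLimiter contract is deliberately small: a limiter only has to answer whether a request may proceed, synchronously via acquire() and asynchronously via aacquire() (declared just below). As a hedged sketch, a custom limiter that simply enforces a minimum gap between requests could look like the following; MinIntervalRateLimiter is invented for illustration and assumes the new module is importable as neo4j_graphrag.llm.rate_limiter.

import threading
import time

from neo4j_graphrag.llm.rate_limiter import BaseRateLimiter


class MinIntervalRateLimiter(BaseRateLimiter):
    """Toy limiter: allow at most one request every `min_interval` seconds."""

    def __init__(self, min_interval: float = 0.5):
        self.min_interval = min_interval
        self._last_allowed = 0.0
        self._lock = threading.Lock()

    def acquire(self, slots: int = 1) -> bool:
        with self._lock:
            now = time.time()
            if now - self._last_allowed >= self.min_interval:
                self._last_allowed = now
                return True
            return False

    async def aacquire(self, slots: int = 1) -> bool:
        # Nothing to await here, so the synchronous logic is simply reused.
        return self.acquire(slots)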
+ + Args: + slots: Number of slots to acquire (default: 1) + + Returns: + True if slots were acquired, False otherwise + """ + + @abstractmethod + async def aacquire(self, slots: int = 1) -> bool: + """Async version of acquire. + + Args: + slots: Number of slots to acquire (default: 1) + + Returns: + True if slots were acquired, False otherwise + """ + + +class TokenTrackingRateLimiter(BaseRateLimiter): + """Base class for rate limiters that can track actual token usage after API calls.""" + + def update_token_usage(self, tokens_used: int) -> None: + """Update the rate limiter with actual token usage from API response. + + Args: + tokens_used: Number of tokens actually consumed by the API call + """ + pass + + async def aupdate_token_usage(self, tokens_used: int) -> None: + """Async version of update_token_usage. + + Args: + tokens_used: Number of tokens actually consumed by the API call + """ + pass + + +class SlotBucketRateLimiter(BaseRateLimiter): + """Slot bucket rate limiter implementation. + + This rate limiter uses a slot bucket algorithm to control the rate of requests. + Slots are added to the bucket at a constant rate, and each request consumes slots. + + Args: + requests_per_second: Maximum number of requests per second + max_bucket_size: Maximum number of slots the bucket can hold + initial_slots: Initial number of slots in the bucket + """ + + def __init__( + self, + requests_per_second: float = 1.0, + max_bucket_size: Optional[int] = None, + initial_slots: Optional[int] = None, + ): + if requests_per_second <= 0: + raise ValueError("requests_per_second must be positive") + + self.requests_per_second = requests_per_second + self.max_bucket_size = max_bucket_size or int(requests_per_second) + self.slots = initial_slots or self.max_bucket_size + self.last_update = time.time() + + # Synchronous lock for sync methods + self._lock = threading.Lock() + + # Async lock for async methods + self._async_lock = asyncio.Lock() + + def _add_slots(self) -> None: + """Add slots to the bucket based on elapsed time.""" + now = time.time() + elapsed = now - self.last_update + slots_to_add = elapsed * self.requests_per_second + + if slots_to_add > 0: + self.slots = min(self.max_bucket_size, self.slots + slots_to_add) + self.last_update = now + + def acquire(self, slots: int = 1) -> bool: + """Attempt to acquire slots from the bucket.""" + with self._lock: + self._add_slots() + if self.slots >= slots: + self.slots -= slots + return True + return False + + async def aacquire(self, slots: int = 1) -> bool: + """Async version of acquire using proper async lock.""" + async with self._async_lock: + self._add_slots() + if self.slots >= slots: + self.slots -= slots + return True + return False + + +class TokenBucketRateLimiter(TokenTrackingRateLimiter): + """Token bucket rate limiter implementation. + + This rate limiter tracks actual token consumption from API responses + and manages rate limits based on tokens per second/minute. 
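For the SlotBucketRateLimiter defined above, the refill is plain arithmetic: elapsed_seconds * requests_per_second slots flow back in, capped at max_bucket_size. A small usage sketch follows; the numbers are illustrative only.

import time

from neo4j_graphrag.llm.rate_limiter import SlotBucketRateLimiter

# 2 requests/second with a burst capacity of 4 slots.
limiter = SlotBucketRateLimiter(requests_per_second=2.0, max_bucket_size=4)

# The first four calls drain the initial burst; the fifth is rejected.
print([limiter.acquire() for _ in range(5)])  # [True, True, True, True, False]

# After ~0.6 s, roughly 0.6 * 2 = 1.2 slots have been refilled.
time.sleep(0.6)
print(limiter.acquire())  # True again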
+ + Args: + tokens_per_second: Maximum number of tokens per second + max_bucket_size: Maximum number of tokens the bucket can hold + initial_tokens: Initial number of tokens in the bucket + estimated_tokens_per_request: Estimated tokens per request for pre-flight checks + """ + + def __init__( + self, + tokens_per_second: float = 1000.0, # Default: 1000 tokens/second + max_bucket_size: Optional[int] = None, + initial_tokens: Optional[int] = None, + estimated_tokens_per_request: int = 500, # Conservative estimate + ): + if tokens_per_second <= 0: + raise ValueError("tokens_per_second must be positive") + if estimated_tokens_per_request <= 0: + raise ValueError("estimated_tokens_per_request must be positive") + + self.tokens_per_second = tokens_per_second + self.max_bucket_size = max_bucket_size or int(tokens_per_second * 60) # 1 minute worth + self.tokens = initial_tokens or self.max_bucket_size + self.last_update = time.time() + self.estimated_tokens_per_request = estimated_tokens_per_request + + # Synchronous lock for sync methods + self._lock = threading.Lock() + + # Async lock for async methods + self._async_lock = asyncio.Lock() + + def _add_tokens(self) -> None: + """Add tokens to the bucket based on elapsed time.""" + now = time.time() + elapsed = now - self.last_update + tokens_to_add = elapsed * self.tokens_per_second + + if tokens_to_add > 0: + self.tokens = min(self.max_bucket_size, self.tokens + tokens_to_add) + self.last_update = now + + def acquire(self, slots: int = 1) -> bool: + """Attempt to acquire tokens for estimated usage. + + Args: + slots: Number of requests (uses estimated_tokens_per_request) + + Returns: + True if estimated tokens were acquired, False otherwise + """ + tokens_needed = slots * self.estimated_tokens_per_request + + with self._lock: + self._add_tokens() + if self.tokens >= tokens_needed: + self.tokens -= tokens_needed + return True + return False + + async def aacquire(self, slots: int = 1) -> bool: + """Async version of acquire using proper async lock.""" + tokens_needed = slots * self.estimated_tokens_per_request + + async with self._async_lock: + self._add_tokens() + if self.tokens >= tokens_needed: + self.tokens -= tokens_needed + return True + return False + + def update_token_usage(self, tokens_used: int) -> None: + """Update with actual token usage and adjust bucket. + + This should be called after API response to correct for estimation errors. + + Args: + tokens_used: Actual tokens consumed by the API call + """ + with self._lock: + # Calculate the difference between estimated and actual usage + estimated_used = self.estimated_tokens_per_request + token_diff = tokens_used - estimated_used + + # Adjust the bucket: if we used more than estimated, deduct more + # if we used less than estimated, give some back + self.tokens = max(0, min(self.max_bucket_size, self.tokens - token_diff)) + + async def aupdate_token_usage(self, tokens_used: int) -> None: + """Async version of update_token_usage.""" + async with self._async_lock: + # Calculate the difference between estimated and actual usage + estimated_used = self.estimated_tokens_per_request + token_diff = tokens_used - estimated_used + + # Adjust the bucket: if we used more than estimated, deduct more + # if we used less than estimated, give some back + self.tokens = max(0, min(self.max_bucket_size, self.tokens - token_diff)) + + +class RetryConfig: + """Configuration for retry behavior when rate limits are hit. 
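Because TokenBucketRateLimiter above has to reserve budget before the response size is known, its acquire() charges estimated_tokens_per_request up front and update_token_usage() settles the difference once the API reports real usage. A hedged sketch of that flow; all values are illustrative, and the argument names follow the class definition above.

from neo4j_graphrag.llm.rate_limiter import TokenBucketRateLimiter

limiter = TokenBucketRateLimiter(
    tokens_per_second=1000.0,          # refill rate
    max_bucket_size=3000,              # burst capacity in tokens
    estimated_tokens_per_request=400,  # pre-flight reservation per request
)

if limiter.acquire():        # reserves 400 estimated tokens
    tokens_used = 250        # pretend the API reported 250 total tokens
    # Settles the estimate: 250 - 400 = -150, so 150 tokens flow back into
    # the bucket (clamped to the [0, max_bucket_size] range).
    limiter.update_token_usage(tokens_used)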
+ + Args: + max_retries: Maximum number of retry attempts + base_delay: Base delay in seconds for exponential backoff + max_delay: Maximum delay in seconds + exponential_base: Base for exponential backoff calculation + jitter: Whether to add random jitter to delays + """ + + def __init__( + self, + max_retries: int = 3, + base_delay: float = 1.0, + max_delay: float = 60.0, + exponential_base: float = 2.0, + jitter: bool = True, + ): + self.max_retries = max_retries + self.base_delay = base_delay + self.max_delay = max_delay + self.exponential_base = exponential_base + self.jitter = jitter + + def get_delay(self, attempt: int) -> float: + """Calculate delay for the given attempt number.""" + delay = self.base_delay * (self.exponential_base ** attempt) + delay = min(delay, self.max_delay) + + if self.jitter: + # Add jitter: ±25% of the delay + jitter_amount = delay * 0.25 + delay += random.uniform(-jitter_amount, jitter_amount) + + return max(0, delay) + + +def is_rate_limit_error(error: Exception) -> bool: + """Check if an error is a rate limit error. + + Args: + error: Exception to check + + Returns: + True if the error appears to be a rate limit error + """ + error_str = str(error).lower() + rate_limit_indicators = [ + "rate limit", + "too many requests", + "429", + "quota exceeded", + "rate_limit_exceeded", + "throttled", + "rate limiting", + ] + + return any(indicator in error_str for indicator in rate_limit_indicators) + + +class CompositeRateLimiter(TokenTrackingRateLimiter): + """Composite rate limiter that enforces multiple rate limits simultaneously. + + This is useful for providers that have both request-based AND token-based limits. + For example, OpenAI has both "requests per minute" and "tokens per minute" limits. + + A request is only allowed if ALL underlying rate limiters approve it. + + Args: + rate_limiters: List of rate limiters that must all approve requests + """ + + def __init__(self, *rate_limiters: BaseRateLimiter): + if not rate_limiters: + raise ValueError("At least one rate limiter must be provided") + self.rate_limiters = list(rate_limiters) + + def acquire(self, slots: int = 1) -> bool: + """Attempt to acquire slots from ALL rate limiters. + + Returns True only if ALL rate limiters approve the request. + If any rate limiter rejects, we need to "rollback" any that already approved. 
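With jitter disabled, RetryConfig.get_delay() above is ordinary exponential backoff clipped at max_delay: base_delay * exponential_base ** attempt. A quick sketch of the resulting schedule, together with the string-matching heuristic in is_rate_limit_error; the configuration values are illustrative.

from neo4j_graphrag.llm.rate_limiter import RetryConfig, is_rate_limit_error

config = RetryConfig(max_retries=4, base_delay=1.0, max_delay=10.0, jitter=False)

# attempt:  0    1    2    3    4
# delay:   1.0  2.0  4.0  8.0  10.0   (the last value is clipped by max_delay)
print([config.get_delay(attempt) for attempt in range(5)])

# The heuristic only inspects the error message, so any exception whose text
# contains a known indicator ("429", "rate limit", ...) counts as retryable.
print(is_rate_limit_error(Exception("HTTP 429: Too Many Requests")))  # True
print(is_rate_limit_error(ValueError("bad input")))                   # False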
+ """ + approved_limiters = [] + + # Try to acquire from each limiter + for limiter in self.rate_limiters: + if limiter.acquire(slots): + approved_limiters.append(limiter) + else: + # Rollback: return tokens/slots to limiters that already approved + self._rollback_acquisitions(approved_limiters, slots) + return False + + # All limiters approved + return True + + async def aacquire(self, slots: int = 1) -> bool: + """Async version of acquire from ALL rate limiters.""" + approved_limiters = [] + + # Try to acquire from each limiter + for limiter in self.rate_limiters: + if await limiter.aacquire(slots): + approved_limiters.append(limiter) + else: + # Rollback: return tokens/slots to limiters that already approved + await self._arollback_acquisitions(approved_limiters, slots) + return False + + # All limiters approved + return True + + def _rollback_acquisitions(self, approved_limiters: list, slots: int) -> None: + """Rollback acquisitions from limiters that already approved.""" + for limiter in approved_limiters: + # Add back the tokens/slots we took + if hasattr(limiter, '_lock') and hasattr(limiter, 'slots'): + # SlotBucketRateLimiter + with limiter._lock: + limiter.slots += slots + elif hasattr(limiter, '_lock') and hasattr(limiter, 'tokens'): + # TokenBucketRateLimiter + with limiter._lock: + # For token limiters, we need to add back the estimated tokens + if hasattr(limiter, 'estimated_tokens_per_request'): + limiter.tokens += slots * limiter.estimated_tokens_per_request + else: + limiter.tokens += slots + + async def _arollback_acquisitions(self, approved_limiters: list, slots: int) -> None: + """Async rollback acquisitions from limiters that already approved.""" + for limiter in approved_limiters: + # Add back the tokens/slots we took + if hasattr(limiter, '_async_lock') and hasattr(limiter, 'slots'): + # SlotBucketRateLimiter + async with limiter._async_lock: + limiter.slots += slots + elif hasattr(limiter, '_async_lock') and hasattr(limiter, 'tokens'): + # TokenBucketRateLimiter + async with limiter._async_lock: + # For token limiters, we need to add back the estimated tokens + if hasattr(limiter, 'estimated_tokens_per_request'): + limiter.tokens += slots * limiter.estimated_tokens_per_request + else: + limiter.tokens += slots + + def update_token_usage(self, tokens_used: int) -> None: + """Update token usage for all underlying token-tracking rate limiters.""" + for limiter in self.rate_limiters: + if isinstance(limiter, TokenTrackingRateLimiter): + limiter.update_token_usage(tokens_used) + + async def aupdate_token_usage(self, tokens_used: int) -> None: + """Async update token usage for all underlying token-tracking rate limiters.""" + for limiter in self.rate_limiters: + if isinstance(limiter, TokenTrackingRateLimiter): + await limiter.aupdate_token_usage(tokens_used) + + +class APIFeedbackRateLimiter(TokenTrackingRateLimiter): + """Rate limiter that syncs with actual API response headers. + + This rate limiter can consume real-time rate limiting information from + API providers (like OpenAI) to maintain accurate state about remaining + quotas and reset times. 
+ + Args: + fallback_requests_per_second: Fallback rate if no API feedback available + fallback_tokens_per_second: Fallback token rate if no API feedback available + estimated_tokens_per_request: Estimated tokens per request for pre-flight + """ + + def __init__( + self, + fallback_requests_per_second: float = 8.0, + fallback_tokens_per_second: float = 500.0, + estimated_tokens_per_request: int = 400, + ): + self.fallback_requests_per_second = fallback_requests_per_second + self.fallback_tokens_per_second = fallback_tokens_per_second + self.estimated_tokens_per_request = estimated_tokens_per_request + + # API feedback state + self.remaining_requests: Optional[int] = None + self.remaining_tokens: Optional[int] = None + self.request_reset_time: Optional[float] = None + self.token_reset_time: Optional[float] = None + self.limit_requests: Optional[int] = None + self.limit_tokens: Optional[int] = None + + # Last API feedback timestamp + self.last_feedback_time: Optional[float] = None + + # Fallback to regular bucket behavior when no API feedback + self.request_slots = 10 # Conservative initial slots + self.token_slots = 2000 # Conservative initial tokens + + # Locks + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + + def _is_feedback_stale(self, max_age_seconds: float = 30.0) -> bool: + """Check if API feedback is too old to be reliable.""" + if self.last_feedback_time is None: + return True + return time.time() - self.last_feedback_time > max_age_seconds + + def _fallback_to_bucket_logic(self) -> None: + """Fallback to bucket logic when API feedback is unavailable/stale.""" + now = time.time() + + # Add request slots over time + if hasattr(self, '_last_request_update'): + elapsed = now - self._last_request_update + slots_to_add = elapsed * self.fallback_requests_per_second + self.request_slots = min(20, self.request_slots + slots_to_add) + self._last_request_update = now + + # Add token slots over time + if hasattr(self, '_last_token_update'): + elapsed = now - self._last_token_update + tokens_to_add = elapsed * self.fallback_tokens_per_second + self.token_slots = min(5000, self.token_slots + tokens_to_add) + self._last_token_update = now + + def acquire(self, slots: int = 1) -> bool: + """Acquire slots using API feedback or fallback logic.""" + with self._lock: + now = time.time() + + # Check if we have fresh API feedback + if not self._is_feedback_stale() and self.remaining_requests is not None: + # Use API feedback + tokens_needed = slots * self.estimated_tokens_per_request + + # Check both request and token limits + if (self.remaining_requests >= slots and + (self.remaining_tokens is None or self.remaining_tokens >= tokens_needed)): + + # Optimistically decrease our tracking + self.remaining_requests -= slots + if self.remaining_tokens is not None: + self.remaining_tokens -= tokens_needed + + return True + else: + return False + else: + # Fallback to bucket logic + self._fallback_to_bucket_logic() + + tokens_needed = slots * self.estimated_tokens_per_request + if self.request_slots >= slots and self.token_slots >= tokens_needed: + self.request_slots -= slots + self.token_slots -= tokens_needed + return True + else: + return False + + async def aacquire(self, slots: int = 1) -> bool: + """Async version of acquire.""" + async with self._async_lock: + # For simplicity, delegate to sync version since it's mostly calculations + return self.acquire(slots) + + def update_from_api_response(self, headers: Dict[str, str]) -> None: + """Update rate limiter state from API response 
headers. + + Args: + headers: Response headers from the API call + """ + with self._lock: + self.last_feedback_time = time.time() + + # Extract OpenAI-style headers + self.remaining_requests = self._parse_header_int(headers, 'x-ratelimit-remaining-requests') + self.remaining_tokens = self._parse_header_int(headers, 'x-ratelimit-remaining-tokens') + self.limit_requests = self._parse_header_int(headers, 'x-ratelimit-limit-requests') + self.limit_tokens = self._parse_header_int(headers, 'x-ratelimit-limit-tokens') + + # Parse reset times when the API reports them as numeric values + self.request_reset_time = self._parse_header_float(headers, 'x-ratelimit-reset-requests') + self.token_reset_time = self._parse_header_float(headers, 'x-ratelimit-reset-tokens') + + async def aupdate_from_api_response(self, headers: Dict[str, str]) -> None: + """Async version of update_from_api_response.""" + async with self._async_lock: + self.update_from_api_response(headers) + + def _parse_header_int(self, headers: Dict[str, str], key: str) -> Optional[int]: + """Parse integer header value.""" + value = headers.get(key) + if value is not None: + try: + return int(value) + except ValueError: + pass + return None + + def _parse_header_float(self, headers: Dict[str, str], key: str) -> Optional[float]: + """Parse float header value.""" + value = headers.get(key) + if value is not None: + try: + return float(value) + except ValueError: + pass + return None + + def update_token_usage(self, tokens_used: int) -> None: + """Update with actual token usage (for compatibility).""" + # With API feedback, this is less important since we get real updates + # But we can still adjust our optimistic tracking + with self._lock: + if self.remaining_tokens is not None: + # Correct our optimistic estimate + estimated_used = self.estimated_tokens_per_request + correction = tokens_used - estimated_used + self.remaining_tokens = max(0, self.remaining_tokens - correction) + + async def aupdate_token_usage(self, tokens_used: int) -> None: + """Async version of update_token_usage.""" + async with self._async_lock: + self.update_token_usage(tokens_used) + + def get_status(self) -> Dict[str, Any]: + """Get current rate limiter status for debugging.""" + now = time.time() + return { + 'remaining_requests': self.remaining_requests, + 'remaining_tokens': self.remaining_tokens, + 'request_reset_time': self.request_reset_time, + 'token_reset_time': self.token_reset_time, + 'request_reset_in_seconds': ( + self.request_reset_time - now if self.request_reset_time else None + ), + 'token_reset_in_seconds': ( + self.token_reset_time - now if self.token_reset_time else None + ), + 'limit_requests': self.limit_requests, + 'limit_tokens': self.limit_tokens, + 'has_fresh_feedback': not self._is_feedback_stale(), + 'last_feedback_age_seconds': ( + now - self.last_feedback_time if self.last_feedback_time else None + ), + }
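To close the loop on the API-feedback limiter: the OpenAI wrappers above pull the x-ratelimit-* headers off the raw response (when the client exposes them) and hand them to update_from_api_response(), after which the limiter trusts the server-reported budgets until the feedback goes stale. A sketch with hand-written header values; everything below is illustrative, not a real API response.

from neo4j_graphrag.llm.rate_limiter import APIFeedbackRateLimiter

limiter = APIFeedbackRateLimiter(estimated_tokens_per_request=400)

# Header names follow the OpenAI convention parsed above; values are invented.
headers = {
    "x-ratelimit-remaining-requests": "58",
    "x-ratelimit-limit-requests": "60",
    "x-ratelimit-remaining-tokens": "14500",
    "x-ratelimit-limit-tokens": "15000",
}
limiter.update_from_api_response(headers)

# While the feedback is fresh, acquire() debits the server-reported budgets:
# one request plus ~400 estimated tokens per call.
print(limiter.acquire())                            # True
print(limiter.get_status()["remaining_requests"])   # 57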
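Putting the pieces together: OpenAILLM and AzureOpenAILLM now pop rate_limiter and retry_config out of **kwargs before forwarding the rest to the OpenAI client, and the base interface wraps the public invoke methods so the limiter is consulted before each provider call. A minimal wiring sketch, assuming the classes are importable from the modules touched in this diff and that OPENAI_API_KEY is set in the environment.

from neo4j_graphrag.llm.openai_llm import OpenAILLM
from neo4j_graphrag.llm.rate_limiter import RetryConfig, TokenBucketRateLimiter

llm = OpenAILLM(
    model_name="gpt-4o-mini",
    model_params={"temperature": 0},
    # Both keyword arguments are popped from **kwargs by OpenAILLM.__init__
    # and passed up to the LLM interface base class.
    rate_limiter=TokenBucketRateLimiter(tokens_per_second=500.0),
    retry_config=RetryConfig(max_retries=3, base_delay=1.0),
)

# The rate-limiting wrapper runs before the provider-specific _invoke_openai
# call, and actual token usage from the response is fed back to the limiter.
response = llm.invoke("Give me one sentence about graphs.")
print(response.content)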