From d6cc91c26b57e81acb9c031e9d4d47bf8d089107 Mon Sep 17 00:00:00 2001 From: Shreyas-Microsoft Date: Wed, 25 Jun 2025 16:41:14 +0530 Subject: [PATCH 1/2] retry logic --- src/backend/sql_agents/convert_script.py | 3 +- .../sql_agents/helpers/comms_manager.py | 254 +++++++++++------- src/backend/sql_agents/process_batch.py | 3 +- 3 files changed, 163 insertions(+), 97 deletions(-) diff --git a/src/backend/sql_agents/convert_script.py b/src/backend/sql_agents/convert_script.py index 3686886..3e55972 100644 --- a/src/backend/sql_agents/convert_script.py +++ b/src/backend/sql_agents/convert_script.py @@ -4,7 +4,6 @@ and updates the database with the results. """ -import asyncio import json import logging @@ -69,7 +68,7 @@ async def convert_script( carry_response = None async for response in chat.invoke(): # TEMPORARY: awaiting bug fix for rate limits - await asyncio.sleep(5) + #await asyncio.sleep(5) carry_response = response if response.role == AuthorRole.ASSISTANT.value: # Our process can terminate with either of these as the last response diff --git a/src/backend/sql_agents/helpers/comms_manager.py b/src/backend/sql_agents/helpers/comms_manager.py index d465ef0..207db68 100644 --- a/src/backend/sql_agents/helpers/comms_manager.py +++ b/src/backend/sql_agents/helpers/comms_manager.py @@ -1,116 +1,184 @@ -"""Manages all agent communication and chat strategies for the SQL agents.""" +"""Optimized CommsManager with parallel processing and performance improvements.""" -from semantic_kernel.agents import AgentGroupChat # pylint: disable=E0611 +import asyncio +import logging +import re +from typing import AsyncIterable, ClassVar, List +from concurrent.futures import ThreadPoolExecutor + +from semantic_kernel.agents import AgentGroupChat from semantic_kernel.agents.strategies import ( SequentialSelectionStrategy, TerminationStrategy, ) +from semantic_kernel.contents import ChatMessageContent +from semantic_kernel.exceptions import AgentInvokeException from 
sql_agents.agents.migrator.response import MigratorResponse from sql_agents.helpers.models import AgentType class CommsManager: - """Manages all agent communication and selection strategies for the SQL agents.""" - - group_chat: AgentGroupChat = None + """Optimized CommsManager with parallel processing and performance improvements.""" - class SelectionStrategy(SequentialSelectionStrategy): - """A strategy for determining which agent should take the next turn in the chat.""" - - # Select the next agent that should take the next turn in the chat - async def select_agent(self, agents, history): - """Check which agent should take the next turn in the chat.""" - match history[-1].name: - case AgentType.MIGRATOR.value: - # The Migrator should go first - agent_name = AgentType.PICKER.value - return next( - (agent for agent in agents if agent.name == agent_name), None - ) - # The Incident Manager should go after the User or the Devops Assistant - case AgentType.PICKER.value: - agent_name = AgentType.SYNTAX_CHECKER.value - return next( - (agent for agent in agents if agent.name == agent_name), None - ) - case AgentType.SYNTAX_CHECKER.value: - agent_name = AgentType.FIXER.value - return next( - (agent for agent in agents if agent.name == agent_name), - None, - ) - case AgentType.FIXER.value: - # The Fixer should always go after the Syntax Checker - agent_name = AgentType.SYNTAX_CHECKER.value - return next( - (agent for agent in agents if agent.name == agent_name), None - ) - case "candidate": - # The candidate message is created in the orchestration loop to pass the - # candidate and source sql queries to the Semantic Verifier - # It is created when the Syntax Checker returns an empty list of errors - agent_name = AgentType.SEMANTIC_VERIFIER.value - return next( - (agent for agent in agents if agent.name == agent_name), - None, - ) - case _: - # Start run with this one - no history - return next( - ( - agent - for agent in agents - if agent.name == AgentType.MIGRATOR.value - 
), - None, - ) + logger: ClassVar[logging.Logger] = logging.getLogger(__name__) + _EXTRACT_WAIT_TIME = r"in (\d+) seconds" - # class for termination strategy - class ApprovalTerminationStrategy(TerminationStrategy): - """ - A strategy for determining when an agent should terminate. - This, combined with the maximum_iterations setting on the group chat, determines - when the agents are finished processing a file when there are no errors. - """ + def __init__( + self, + agent_dict: dict[AgentType, object], + exception_types: tuple = (Exception,), + max_retries: int = 3, # reduc from 10 + initial_delay: float = 0.5, # reduced from 1.0 + backoff_factor: float = 1.5, # reduced from 2.0 + simple_truncation: int = 50, # more aggr truncation + batch_size: int = 10, # process in batches + max_workers: int = 4, # parallel processing + ): + self.max_retries = max_retries + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + self.exception_types = exception_types + self.simple_truncation = simple_truncation + self.batch_size = batch_size + self.max_workers = max_workers + + # Thread pool for parallel processing + self.executor = ThreadPoolExecutor(max_workers=max_workers) - async def should_agent_terminate(self, agent, history): - """Check if the agent should terminate.""" - # May need to convert to models to get usable content using history[-1].name - terminate: bool = False - lower_case_hist: str = history[-1].content.lower() - match history[-1].name: - case AgentType.MIGRATOR.value: - response = MigratorResponse.model_validate_json( - lower_case_hist or "" - ) - if ( - response.input_error is not None - or response.rai_error is not None - ): - terminate = True - case AgentType.SEMANTIC_VERIFIER.value: - # Always terminate after the Semantic Verifier runs - terminate = True - case _: - # If the agent is not the Migrator or Semantic Verifier, don't terminate - # Note that the Syntax Checker and Fixer loop are only terminated by correct SQL - # or by 
iterations exceeding the max_iterations setting - pass - - return terminate - - def __init__(self, agent_dict): - """Initialize the CommsManager and agent_chat with the given agents.""" self.group_chat = AgentGroupChat( agents=agent_dict.values(), - termination_strategy=self.ApprovalTerminationStrategy( + termination_strategy=self.OptimizedTerminationStrategy( agents=[ agent_dict[AgentType.MIGRATOR], agent_dict[AgentType.SEMANTIC_VERIFIER], ], - maximum_iterations=10, + maximum_iterations=5, # Reduced from 10 automatic_reset=True, ), - selection_strategy=self.SelectionStrategy(agents=agent_dict.values()), + selection_strategy=self.ParallelSelectionStrategy( + agents=agent_dict.values(), + max_workers=max_workers + ), ) + + async def async_invoke_batch(self, inputs: List[str]) -> AsyncIterable[ChatMessageContent]: + """Process multiple inputs in parallel batches.""" + # Process inputs in batches + for i in range(0, len(inputs), self.batch_size): + batch = inputs[i:i + self.batch_size] + + # Process batch in parallel + tasks = [self._process_single_input(input_item) for input_item in batch] + batch_results = await asyncio.gather(*tasks, return_exceptions=True) + + for result in batch_results: + if isinstance(result, Exception): + self.logger.error(f"Batch processing error: {result}") + continue + + async for item in result: + yield item + + async def _process_single_input(self, input_item: str) -> AsyncIterable[ChatMessageContent]: + """Process a single input with optimized retry logic.""" + attempt = 0 + current_delay = self.initial_delay + + while attempt < self.max_retries: + try: + # Aggressive history truncation + if len(self.group_chat.history) > self.simple_truncation: + # Keep only the most recent messages + self.group_chat.history = self.group_chat.history[-self.simple_truncation:] + + # Add input to chat + self.group_chat.add_chat_message(ChatMessageContent( + role="user", + content=input_item + )) + + async_iter = self.group_chat.invoke() + async for item 
in async_iter: + yield item + break + + except AgentInvokeException as aie: + attempt += 1 + if attempt >= self.max_retries: + self.logger.error( + "Input processing failed after %d attempts: %s", + self.max_retries, str(aie) + ) + # Don't raise, continue with next input + break + + # Faster retry with shorter delays + match = re.search(self._EXTRACT_WAIT_TIME, str(aie)) + if match: + current_delay = min(int(match.group(1)), 5) # Cap at 5 seconds + else: + current_delay = min(current_delay * self.backoff_factor, 10) # Cap at 10 seconds + + self.logger.warning( + "Attempt %d/%d failed. Retrying in %.2f seconds...", + attempt, self.max_retries, current_delay + ) + await asyncio.sleep(current_delay) + + class ParallelSelectionStrategy(SequentialSelectionStrategy): + """Optimized selection strategy with parallel processing capabilities.""" + + def __init__(self, agents, max_workers: int = 4): + super().__init__(agents) + self.max_workers = max_workers + + async def select_agent(self, agents, history): + """Select agent with optimized logic and parallel processing hints.""" + if not history: + return next((agent for agent in agents if agent.name == AgentType.MIGRATOR.value), None) + + last_agent = history[-1].name + + # Optimized selection logic with fewer transitions + agent_transitions = { + AgentType.MIGRATOR.value: AgentType.PICKER.value, + AgentType.PICKER.value: AgentType.SYNTAX_CHECKER.value, + AgentType.SYNTAX_CHECKER.value: AgentType.FIXER.value, + AgentType.FIXER.value: AgentType.SEMANTIC_VERIFIER.value, # Skip syntax check + "candidate": AgentType.SEMANTIC_VERIFIER.value, + } + + next_agent_name = agent_transitions.get(last_agent, AgentType.MIGRATOR.value) + return next((agent for agent in agents if agent.name == next_agent_name), None) + + class OptimizedTerminationStrategy(TerminationStrategy): + """Optimized termination strategy with faster decision making.""" + + async def should_agent_terminate(self, agent, history): + """Determine termination with 
optimized checks.""" + if not history: + return False + + last_message = history[-1] + lower_case_content = last_message.content.lower() + + # Fast termination checks + if last_message.name == AgentType.SEMANTIC_VERIFIER.value: + return True + + if last_message.name == AgentType.MIGRATOR.value: + try: + # Faster JSON parsing with error handling + response = MigratorResponse.model_validate_json(lower_case_content or "{}") + return bool(response.input_error or response.rai_error) + except Exception: + # If parsing fails, assume no termination needed + return False + + return False + + def cleanup(self): + """Clean up resources.""" + if hasattr(self, 'executor'): + self.executor.shutdown(wait=False) \ No newline at end of file diff --git a/src/backend/sql_agents/process_batch.py b/src/backend/sql_agents/process_batch.py index cb22efe..66e798a 100644 --- a/src/backend/sql_agents/process_batch.py +++ b/src/backend/sql_agents/process_batch.py @@ -4,7 +4,6 @@ It is the main entry point for the SQL migration process. """ -import asyncio import logging from api.status_updates import send_status_update @@ -132,7 +131,7 @@ async def process_batch_async( else: await batch_service.update_file_counts(file["file_id"]) # TEMPORARY: awaiting bug fix for rate limits - await asyncio.sleep(5) + #await asyncio.sleep(5) except UnicodeDecodeError as ucde: logger.error("Error decoding file: %s", file) logger.error("Error decoding file. 
%s", ucde) From aa101e53c97f1d87039844b89cdf4f3a6a04bd19 Mon Sep 17 00:00:00 2001 From: Shreyas-Microsoft Date: Thu, 26 Jun 2025 12:54:12 +0530 Subject: [PATCH 2/2] Less retry --- src/backend/sql_agents/convert_script.py | 329 +++++++------ .../sql_agents/helpers/comms_manager.py | 432 +++++++++++++----- 2 files changed, 500 insertions(+), 261 deletions(-) diff --git a/src/backend/sql_agents/convert_script.py b/src/backend/sql_agents/convert_script.py index 3e55972..7aaa423 100644 --- a/src/backend/sql_agents/convert_script.py +++ b/src/backend/sql_agents/convert_script.py @@ -1,9 +1,10 @@ -"""This module loops through each file in a batch and processes it using the SQL agents. It sets up a group chat for the agents, sends the source script to the chat, and processes the responses from the agents. It also reports in real-time to the client using websockets and updates the database with the results. 
""" +import asyncio import json import logging @@ -45,6 +46,13 @@ async def convert_script( # Setup the group chat for the agents chat = CommsManager(sql_agents.idx_agents).group_chat + #retry logic comms manager + comms_manager = CommsManager( + sql_agents.idx_agents, + max_retries=3, # Retry up to 3 times for rate limits + initial_delay=0.2, # Start with a 0.2 second delay + backoff_factor=1.2, # Multiply delay by 1.2 each retry + ) # send websocket notification that file processing has started send_status_update( @@ -62,160 +70,211 @@ current_migration = "No migration" is_complete: bool = False while not is_complete: - await chat.add_chat_message( + await comms_manager.group_chat.add_chat_message( ChatMessageContent(role=AuthorRole.USER, content=source_script) ) carry_response = None - async for response in chat.invoke(): - # TEMPORARY: awaiting bug fix for rate limits - #await asyncio.sleep(5) - carry_response = response - if response.role == AuthorRole.ASSISTANT.value: - # Our process can terminate with either of these as the last response - # before syntax check - match response.name: - case AgentType.MIGRATOR.value: - result = MigratorResponse.model_validate_json( - response.content or "" - ) - if result.input_error or result.rai_error: - # If there is an error in input, we end the processing here. - # We do not include this in termination to avoid forking the chat process. 
- description = { - "role": response.role, - "name": response.name or "*", - "content": response.content, - } - await batch_service.create_file_log( - str(file.file_id), - description, - current_migration, - LogType.ERROR, - AgentType(response.name), - AuthorRole(response.role), + try: + + async for response in comms_manager.async_invoke(): + # TEMPORARY: awaiting bug fix for rate limits + # await asyncio.sleep(5) + carry_response = response + if response.role == AuthorRole.ASSISTANT.value: + # Our process can terminate with either of these as the last response + # before syntax check + match response.name: + case AgentType.MIGRATOR.value: + result = MigratorResponse.model_validate_json( + response.content or "" ) - current_migration = None - break - case AgentType.SYNTAX_CHECKER.value: - result = SyntaxCheckerResponse.model_validate_json( - response.content.lower() or "" - ) - # If there are no syntax errors, we can move to the semantic verifier - # We provide both scripts by injecting them into the chat history - if result.syntax_errors == []: - chat.history.add_message( - ChatMessageContent( - role=AuthorRole.USER, - name="candidate", - content=( - f"source_script: {source_script}, \n " - + f"migrated_script: {current_migration}" - ), + if result.input_error or result.rai_error: + # If there is an error in input, we end the processing here. + # We do not include this in termination to avoid forking the chat process. 
+ description = { + "role": response.role, + "name": response.name or "*", + "content": response.content, + } + await batch_service.create_file_log( + str(file.file_id), + description, + current_migration, + LogType.ERROR, + AgentType(response.name), + AuthorRole(response.role), ) + current_migration = None + break + case AgentType.SYNTAX_CHECKER.value: + result = SyntaxCheckerResponse.model_validate_json( + response.content.lower() or "" ) - case AgentType.PICKER.value: - result = PickerResponse.model_validate_json( - response.content or "" - ) - current_migration = result.picked_query - case AgentType.FIXER.value: - result = FixerResponse.model_validate_json( - response.content or "" - ) - current_migration = result.fixed_query - case AgentType.SEMANTIC_VERIFIER.value: - logger.info( - "Semantic verifier agent response: %s", response.content - ) - result = SemanticVerifierResponse.model_validate_json( - response.content or "" - ) - - # If the semantic verifier agent returns a difference, we need to report it - if len(result.differences) > 0: - description = { - "role": AuthorRole.ASSISTANT.value, - "name": AgentType.SEMANTIC_VERIFIER.value, - "content": "\n".join(result.differences), - } - logger.info( - "Semantic verification had issues. Pass with warnings." 
- ) - # send status update to the client of type in progress with agent status - send_status_update( - status=FileProcessUpdate( - file.batch_id, - file.file_id, - ProcessStatus.COMPLETED, - AgentType.SEMANTIC_VERIFIER, - result.summary, - FileResult.WARNING, - ), + # If there are no syntax errors, we can move to the semantic verifier + # We provide both scripts by injecting them into the chat history + if result.syntax_errors == []: + comms_manager.group_chat.history.add_message( + ChatMessageContent( + role=AuthorRole.USER, + name="candidate", + content=( + f"source_script: {source_script}, \n " + + f"migrated_script: {current_migration}" + ), + ) + ) + case AgentType.PICKER.value: + result = PickerResponse.model_validate_json( + response.content or "" ) - await batch_service.create_file_log( - str(file.file_id), - description, - current_migration, - LogType.WARNING, - AgentType.SEMANTIC_VERIFIER, - AuthorRole.ASSISTANT, + current_migration = result.picked_query + case AgentType.FIXER.value: + result = FixerResponse.model_validate_json( + response.content or "" ) - - elif response == "": - # If the semantic verifier agent returns an empty response + current_migration = result.fixed_query + case AgentType.SEMANTIC_VERIFIER.value: logger.info( - "Semantic verification had no return value. Pass with warnings." + "Semantic verifier agent response: %s", response.content + ) + result = SemanticVerifierResponse.model_validate_json( + response.content or "" ) - # send status update to the client of type in progress with agent status - send_status_update( - status=FileProcessUpdate( - file.batch_id, - file.file_id, - ProcessStatus.COMPLETED, + + # If the semantic verifier agent returns a difference, we need to report it + if len(result.differences) > 0: + description = { + "role": AuthorRole.ASSISTANT.value, + "name": AgentType.SEMANTIC_VERIFIER.value, + "content": "\n".join(result.differences), + } + logger.info( + "Semantic verification had issues. Pass with warnings." 
+ ) + # send status update to the client of type in progress with agent status + send_status_update( + status=FileProcessUpdate( + file.batch_id, + file.file_id, + ProcessStatus.COMPLETED, + AgentType.SEMANTIC_VERIFIER, + result.summary, + FileResult.WARNING, + ), + ) + await batch_service.create_file_log( + str(file.file_id), + description, + current_migration, + LogType.WARNING, AgentType.SEMANTIC_VERIFIER, + AuthorRole.ASSISTANT, + ) + + elif response == "": + # If the semantic verifier agent returns an empty response + logger.info( + "Semantic verification had no return value. Pass with warnings." + ) + # send status update to the client of type in progress with agent status + send_status_update( + status=FileProcessUpdate( + file.batch_id, + file.file_id, + ProcessStatus.COMPLETED, + AgentType.SEMANTIC_VERIFIER, + "No return value from semantic verifier agent.", + FileResult.WARNING, + ), + ) + await batch_service.create_file_log( + str(file.file_id), "No return value from semantic verifier agent.", - FileResult.WARNING, - ), - ) - await batch_service.create_file_log( - str(file.file_id), - "No return value from semantic verifier agent.", - current_migration, - LogType.WARNING, - AgentType.SEMANTIC_VERIFIER, - AuthorRole.ASSISTANT, - ) + current_migration, + LogType.WARNING, + AgentType.SEMANTIC_VERIFIER, + AuthorRole.ASSISTANT, + ) + + description = { + "role": response.role, + "name": response.name or "*", + "content": response.content, + } - description = { - "role": response.role, - "name": response.name or "*", - "content": response.content, - } + logger.info(description) - logger.info(description) + # send status update to the client of type in progress with agent status + # send_status_update( + # status=FileProcessUpdate( + # file.batch_id, + # file.file_id, + # ProcessStatus.IN_PROGRESS, + # AgentType(response.name), + # json.loads(response.content)["summary"], + # FileResult.INFO, + # ), + # ) + # Safely parse response content to avoid crashing on 
malformed or incomplete JSON + #start + try: + parsed_content = json.loads(response.content or "{}") + except json.JSONDecodeError: + logger.warning("Invalid JSON from agent: %s", response.content) + parsed_content = { + "input_summary": "", + "candidates": [], + "summary": "", + "input_error": "", + "rai_error": "Invalid JSON from agent.", + } - # send status update to the client of type in progress with agent status + # Send status update using safe fallback values + send_status_update( + status=FileProcessUpdate( + file.batch_id, + file.file_id, + ProcessStatus.IN_PROGRESS, + AgentType(response.name), + parsed_content.get("summary", ""), + FileResult.INFO, + ), + ) + ##end + await batch_service.create_file_log( + str(file.file_id), + description, + current_migration, + LogType.INFO, + AgentType(response.name), + AuthorRole(response.role), + ) + except Exception as e: + #logger.error("Error during chat.invoke(): %s", str(e)) + logger.error("Error during comms_manager.async_invoke(): %s", str(e)) + # Log the error to the batch service for tracking + await batch_service.create_file_log( + str(file.file_id), + f"Critical error during agent communication: {str(e)}", + current_migration, + LogType.ERROR, + AgentType.ALL, + AuthorRole.ASSISTANT, + ) + # Send error status update send_status_update( status=FileProcessUpdate( file.batch_id, file.file_id, - ProcessStatus.IN_PROGRESS, - AgentType(response.name), - json.loads(response.content)["summary"], - FileResult.INFO, + ProcessStatus.COMPLETED, + AgentType.ALL, + f"Processing failed: {str(e)}", + FileResult.ERROR, ), ) + break # Exit the while loop on critical error - await batch_service.create_file_log( - str(file.file_id), - description, - current_migration, - LogType.INFO, - AgentType(response.name), - AuthorRole(response.role), - ) - - if chat.is_complete: + if comms_manager.group_chat.is_complete: is_complete = True break @@ -295,6 +354,4 @@ async def validate_migration( log_type=LogType.SUCCESS, 
agent_type=AgentType.ALL, author_role=AuthorRole.ASSISTANT, - ) - - return True + ) \ No newline at end of file diff --git a/src/backend/sql_agents/helpers/comms_manager.py b/src/backend/sql_agents/helpers/comms_manager.py index 207db68..fcf45b5 100644 --- a/src/backend/sql_agents/helpers/comms_manager.py +++ b/src/backend/sql_agents/helpers/comms_manager.py @@ -1,12 +1,12 @@ -"""Optimized CommsManager with parallel processing and performance improvements.""" +"""Manages all agent communication and chat strategies for the SQL agents.""" import asyncio +import copy import logging import re -from typing import AsyncIterable, ClassVar, List -from concurrent.futures import ThreadPoolExecutor +from typing import AsyncIterable, ClassVar -from semantic_kernel.agents import AgentGroupChat +from semantic_kernel.agents import AgentGroupChat # pylint: disable=E0611 from semantic_kernel.agents.strategies import ( SequentialSelectionStrategy, TerminationStrategy, @@ -19,166 +19,348 @@ class CommsManager: - """Optimized CommsManager with parallel processing and performance improvements.""" + """Manages all agent communication and selection strategies for the SQL agents.""" + # Class level logger logger: ClassVar[logging.Logger] = logging.getLogger(__name__) + + # regex to extract the recommended wait time in seconds from response _EXTRACT_WAIT_TIME = r"in (\d+) seconds" + # Rate limit error indicators + _RATE_LIMIT_INDICATORS = [ + "rate limit", + "too many requests", + "quota exceeded", + "throttled", + "429", + ] + + group_chat: AgentGroupChat = None + + class SelectionStrategy(SequentialSelectionStrategy): + """A strategy for determining which agent should take the next turn in the chat.""" + + # Select the next agent that should take the next turn in the chat + async def select_agent(self, agents, history): + """Check which agent should take the next turn in the chat.""" + match history[-1].name: + case AgentType.MIGRATOR.value: + # The Migrator should go first + agent_name 
= AgentType.PICKER.value + return next( + (agent for agent in agents if agent.name == agent_name), None + ) + # The Incident Manager should go after the User or the Devops Assistant + case AgentType.PICKER.value: + agent_name = AgentType.SYNTAX_CHECKER.value + return next( + (agent for agent in agents if agent.name == agent_name), None + ) + case AgentType.SYNTAX_CHECKER.value: + agent_name = AgentType.FIXER.value + return next( + (agent for agent in agents if agent.name == agent_name), + None, + ) + case AgentType.FIXER.value: + # The Fixer should always go after the Syntax Checker + agent_name = AgentType.SYNTAX_CHECKER.value + return next( + (agent for agent in agents if agent.name == agent_name), None + ) + case "candidate": + # The candidate message is created in the orchestration loop to pass the + # candidate and source sql queries to the Semantic Verifier + # It is created when the Syntax Checker returns an empty list of errors + agent_name = AgentType.SEMANTIC_VERIFIER.value + return next( + (agent for agent in agents if agent.name == agent_name), + None, + ) + case _: + # Start run with this one - no history + return next( + ( + agent + for agent in agents + if agent.name == AgentType.MIGRATOR.value + ), + None, + ) + + # class for termination strategy + class ApprovalTerminationStrategy(TerminationStrategy): + """ + A strategy for determining when an agent should terminate. + This, combined with the maximum_iterations setting on the group chat, determines + when the agents are finished processing a file when there are no errors. 
+ """ + + async def should_agent_terminate(self, agent, history): + """Check if the agent should terminate.""" + # May need to convert to models to get usable content using history[-1].name + terminate: bool = False + lower_case_hist: str = history[-1].content.lower() + match history[-1].name: + case AgentType.MIGRATOR.value: + response = MigratorResponse.model_validate_json( + lower_case_hist or "" + ) + if ( + response.input_error is not None + or response.rai_error is not None + ): + terminate = True + case AgentType.SEMANTIC_VERIFIER.value: + # Always terminate after the Semantic Verifier runs + terminate = True + case _: + # If the agent is not the Migrator or Semantic Verifier, don't terminate + # Note that the Syntax Checker and Fixer loop are only terminated by correct SQL + # or by iterations exceeding the max_iterations setting + pass + + return terminate + def __init__( - self, - agent_dict: dict[AgentType, object], + self, + agent_dict, exception_types: tuple = (Exception,), - max_retries: int = 3, # reduc from 10 - initial_delay: float = 0.5, # reduced from 1.0 - backoff_factor: float = 1.5, # reduced from 2.0 - simple_truncation: int = 50, # more aggr truncation - batch_size: int = 10, # process in batches - max_workers: int = 4, # parallel processing + max_retries: int = 10, + initial_delay: float = 0.5, + backoff_factor: float = 1.5, + simple_truncation: int = None, ): - self.max_retries = max_retries - self.initial_delay = initial_delay - self.backoff_factor = backoff_factor - self.exception_types = exception_types - self.simple_truncation = simple_truncation - self.batch_size = batch_size - self.max_workers = max_workers + """Initialize the CommsManager and agent_chat with the given agents. 
- # Thread pool for parallel processing - self.executor = ThreadPoolExecutor(max_workers=max_workers) - + Args: + agent_dict: Dictionary of agents + exception_types: Tuple of exception types that should trigger a retry + max_retries: Maximum number of retry attempts (default: 10) + initial_delay: Initial delay in seconds before first retry (default: 0.5) + backoff_factor: Factor by which the delay increases with each retry (default: 1.5) + simple_truncation: Optional truncation limit for chat history + """ + # Initialize the group chat (exactly like original) self.group_chat = AgentGroupChat( agents=agent_dict.values(), - termination_strategy=self.OptimizedTerminationStrategy( + termination_strategy=self.ApprovalTerminationStrategy( agents=[ agent_dict[AgentType.MIGRATOR], agent_dict[AgentType.SEMANTIC_VERIFIER], ], - maximum_iterations=5, # Reduced from 10 + maximum_iterations=10, automatic_reset=True, ), - selection_strategy=self.ParallelSelectionStrategy( - agents=agent_dict.values(), - max_workers=max_workers - ), + selection_strategy=self.SelectionStrategy(agents=agent_dict.values()), ) - async def async_invoke_batch(self, inputs: List[str]) -> AsyncIterable[ChatMessageContent]: - """Process multiple inputs in parallel batches.""" - # Process inputs in batches - for i in range(0, len(inputs), self.batch_size): - batch = inputs[i:i + self.batch_size] - - # Process batch in parallel - tasks = [self._process_single_input(input_item) for input_item in batch] - batch_results = await asyncio.gather(*tasks, return_exceptions=True) - - for result in batch_results: - if isinstance(result, Exception): - self.logger.error(f"Batch processing error: {result}") - continue - - async for item in result: - yield item + # Store retry configuration + self.max_retries = max_retries + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + self.exception_types = exception_types + self.simple_truncation = simple_truncation - async def _process_single_input(self, 
input_item: str) -> AsyncIterable[ChatMessageContent]: - """Process a single input with optimized retry logic.""" + # Adaptive retry state - starts optimistic + self._rate_limit_detected_recently = False + self._consecutive_successes = 0 + self._session_has_rate_limits = False + + def _is_rate_limit_error(self, error_message: str) -> bool: + """Check if the error message indicates a rate limit issue.""" + error_lower = error_message.lower() + return any(indicator in error_lower for indicator in self._RATE_LIMIT_INDICATORS) + + def _should_use_zero_overhead_path(self) -> bool: + """ + Determine if we should use zero-overhead path. + + Use zero overhead when: + - No rate limits detected in current session AND + - We have some successful calls OR this is the first call + """ + return ( + not self._session_has_rate_limits + and (self._consecutive_successes >= 1 or self._consecutive_successes == 0) + + ) + + async def _zero_overhead_invoke(self) -> AsyncIterable[ChatMessageContent]: + """Pure delegation to original group_chat.invoke() - zero overhead.""" + async for item in self.group_chat.invoke(): + yield item + + async def _retry_enabled_invoke(self) -> AsyncIterable[ChatMessageContent]: + """Invoke with retry logic - only used when rate limits are expected.""" attempt = 0 current_delay = self.initial_delay + # Create history snapshot only when we need it + history_snapshot = None + while attempt < self.max_retries: try: - # Aggressive history truncation - if len(self.group_chat.history) > self.simple_truncation: - # Keep only the most recent messages - self.group_chat.history = self.group_chat.history[-self.simple_truncation:] - - # Add input to chat - self.group_chat.add_chat_message(ChatMessageContent( - role="user", - content=input_item - )) - - async_iter = self.group_chat.invoke() - async for item in async_iter: + # Apply truncation if configured and on first attempt + if ( + attempt == 0 + and self.simple_truncation + and len(self.group_chat.history) > 
self.simple_truncation + ): + if history_snapshot is None: + history_snapshot = copy.deepcopy(self.group_chat.history) + self.group_chat.history = history_snapshot[-self.simple_truncation:] + + + # Execute and yield results + async for item in self.group_chat.invoke(): yield item - break + + # Success - exit retry loop + return except AgentInvokeException as aie: + # Create snapshot only when we actually need to retry + if history_snapshot is None: + history_snapshot = copy.deepcopy(self.group_chat.history) + attempt += 1 if attempt >= self.max_retries: self.logger.error( - "Input processing failed after %d attempts: %s", - self.max_retries, str(aie) + "AgentInvokeException: Max retries (%d) exceeded. Final error: %s", + self.max_retries, + str(aie), ) - # Don't raise, continue with next input - break + raise + + # Restore history from snapshot + self.group_chat.history = copy.deepcopy(history_snapshot) - # Faster retry with shorter delays - match = re.search(self._EXTRACT_WAIT_TIME, str(aie)) - if match: - current_delay = min(int(match.group(1)), 5) # Cap at 5 seconds + # Check for rate limit specific wait time + wait_time_match = re.search(self._EXTRACT_WAIT_TIME, str(aie)) + if wait_time_match: + current_delay = int(wait_time_match.group(1)) + self.logger.info( + "Rate limit detected, waiting %d seconds as requested", + current_delay + ) else: - current_delay = min(current_delay * self.backoff_factor, 10) # Cap at 10 seconds + current_delay = self.initial_delay * (self.backoff_factor ** (attempt - 1)) self.logger.warning( - "Attempt %d/%d failed. Retrying in %.2f seconds...", - attempt, self.max_retries, current_delay + "Attempt %d/%d failed with AgentInvokeException: %s. 
Retrying in %.2f seconds...", + attempt, + self.max_retries, + str(aie), + current_delay, ) + await asyncio.sleep(current_delay) - class ParallelSelectionStrategy(SequentialSelectionStrategy): - """Optimized selection strategy with parallel processing capabilities.""" - def __init__(self, agents, max_workers: int = 4): - super().__init__(agents) - self.max_workers = max_workers - async def select_agent(self, agents, history): - """Select agent with optimized logic and parallel processing hints.""" - if not history: - return next((agent for agent in agents if agent.name == AgentType.MIGRATOR.value), None) - - last_agent = history[-1].name - - # Optimized selection logic with fewer transitions - agent_transitions = { - AgentType.MIGRATOR.value: AgentType.PICKER.value, - AgentType.PICKER.value: AgentType.SYNTAX_CHECKER.value, - AgentType.SYNTAX_CHECKER.value: AgentType.FIXER.value, - AgentType.FIXER.value: AgentType.SEMANTIC_VERIFIER.value, # Skip syntax check - "candidate": AgentType.SEMANTIC_VERIFIER.value, - } - - next_agent_name = agent_transitions.get(last_agent, AgentType.MIGRATOR.value) - return next((agent for agent in agents if agent.name == next_agent_name), None) - - class OptimizedTerminationStrategy(TerminationStrategy): - """Optimized termination strategy with faster decision making.""" - async def should_agent_terminate(self, agent, history): - """Determine termination with optimized checks.""" - if not history: - return False - - last_message = history[-1] - lower_case_content = last_message.content.lower() - - # Fast termination checks - if last_message.name == AgentType.SEMANTIC_VERIFIER.value: - return True + except self.exception_types as e: + if history_snapshot is None: + history_snapshot = copy.deepcopy(self.group_chat.history) + + + attempt += 1 + if attempt >= self.max_retries: + self.logger.error( + "Generic exception: Max retries (%d) exceeded. 
Final error: %s", + self.max_retries, + str(e), + ) + raise + + # Restore history from snapshot + self.group_chat.history = copy.deepcopy(history_snapshot) + + current_delay = self.initial_delay * (self.backoff_factor ** (attempt - 1)) - if last_message.name == AgentType.MIGRATOR.value: - try: - # Faster JSON parsing with error handling - response = MigratorResponse.model_validate_json(lower_case_content or "{}") - return bool(response.input_error or response.rai_error) - except Exception: - # If parsing fails, assume no termination needed - return False - - return False - - def cleanup(self): - """Clean up resources.""" - if hasattr(self, 'executor'): - self.executor.shutdown(wait=False) \ No newline at end of file + self.logger.warning( + "Attempt %d/%d failed with %s: %s. Retrying in %.2f seconds...", + attempt, + self.max_retries, + type(e).__name__, + str(e), + current_delay, + ) + + await asyncio.sleep(current_delay) + + + async def async_invoke(self) -> AsyncIterable[ChatMessageContent]: + """ + Optimized invoke method that dynamically chooses between zero-overhead and retry modes. 
+ + Performance targets: + - 200k tokens: 1.2 mins (zero overhead when no rate limits expected) + - 30k-50k tokens: 1.8-2 mins (retry overhead only when needed) + """ + + # Decide which path to take + use_zero_overhead = self._should_use_zero_overhead_path() + + if use_zero_overhead: + # Zero overhead path - matches original performance exactly + try: + async for item in self._zero_overhead_invoke(): + yield item + + # Track success + self._consecutive_successes += 1 + return + + except (AgentInvokeException, *self.exception_types) as e: + # Check if this is a rate limit error + error_str = str(e) + if self._is_rate_limit_error(error_str): + self.logger.info( + "Rate limit detected on zero-overhead path, switching to retry mode for this session" + ) + self._session_has_rate_limits = True + self._rate_limit_detected_recently = True + # Fall through to retry logic below + else: + # Non-rate-limit error, re-raise immediately (fail fast) + self.logger.error("Non-rate-limit error in zero-overhead path: %s", error_str) + raise + + # Retry-enabled path - used when rate limits are expected or detected + try: + async for item in self._retry_enabled_invoke(): + yield item + + # Track success + self._consecutive_successes += 1 + + # Gradually become more optimistic about rate limits + if self._consecutive_successes >= 5: + self._rate_limit_detected_recently = False + # Note: We keep _session_has_rate_limits = True to remember for this session + + + + + except Exception as e: + # Reset success counter on failure + self._consecutive_successes = 0 + self._rate_limit_detected_recently = True + raise + + async def invoke_async(self): + """Legacy method - maintained for compatibility.""" + return self.group_chat.invoke() + + def reset_rate_limit_state(self): + """ + Reset rate limit detection state - call this between different processing sessions + if you want to reset the adaptive behavior. 
+ """ + self._rate_limit_detected_recently = False + self._consecutive_successes = 0 + self._session_has_rate_limits = False + self.logger.info("Rate limit detection state reset") \ No newline at end of file