diff --git a/src/backend/sql_agents/convert_script.py b/src/backend/sql_agents/convert_script.py index 3686886..6f89c78 100644 --- a/src/backend/sql_agents/convert_script.py +++ b/src/backend/sql_agents/convert_script.py @@ -4,7 +4,6 @@ and updates the database with the results. """ -import asyncio import json import logging @@ -45,7 +44,14 @@ async def convert_script( logger.info("Migrating query: %s\n", source_script) # Setup the group chat for the agents - chat = CommsManager(sql_agents.idx_agents).group_chat + # chat = CommsManager(sql_agents.idx_agents).group_chat + # retry logic comms manager + comms_manager = CommsManager( + sql_agents.idx_agents, + max_retries=5, # Retry up to 5 times for rate limits + initial_delay=1.0, # Start with 1 second delay + backoff_factor=2.0, # Double delay each retry + ) # send websocket notification that file processing has started send_status_update( @@ -63,160 +69,211 @@ async def convert_script( current_migration = "No migration" is_complete: bool = False while not is_complete: - await chat.add_chat_message( + await comms_manager.group_chat.add_chat_message( ChatMessageContent(role=AuthorRole.USER, content=source_script) ) carry_response = None - async for response in chat.invoke(): - # TEMPORARY: awaiting bug fix for rate limits - await asyncio.sleep(5) - carry_response = response - if response.role == AuthorRole.ASSISTANT.value: - # Our process can terminate with either of these as the last response - # before syntax check - match response.name: - case AgentType.MIGRATOR.value: - result = MigratorResponse.model_validate_json( - response.content or "" - ) - if result.input_error or result.rai_error: - # If there is an error in input, we end the processing here. - # We do not include this in termination to avoid forking the chat process. - description = { - "role": response.role, - "name": response.name or "*", - "content": response.content, - } - await batch_service.create_file_log( - str(file.file_id), - description, - current_migration, - LogType.ERROR, - AgentType(response.name), - AuthorRole(response.role), + try: + + async for response in comms_manager.async_invoke(): + # TEMPORARY: awaiting bug fix for rate limits + # await asyncio.sleep(5) + carry_response = response + if response.role == AuthorRole.ASSISTANT.value: + # Our process can terminate with either of these as the last response + # before syntax check + match response.name: + case AgentType.MIGRATOR.value: + result = MigratorResponse.model_validate_json( + response.content or "" ) - current_migration = None - break - case AgentType.SYNTAX_CHECKER.value: - result = SyntaxCheckerResponse.model_validate_json( - response.content.lower() or "" - ) - # If there are no syntax errors, we can move to the semantic verifier - # We provide both scripts by injecting them into the chat history - if result.syntax_errors == []: - chat.history.add_message( - ChatMessageContent( - role=AuthorRole.USER, - name="candidate", - content=( - f"source_script: {source_script}, \n " - + f"migrated_script: {current_migration}" - ), + if result.input_error or result.rai_error: + # If there is an error in input, we end the processing here. + # We do not include this in termination to avoid forking the chat process. 
+ description = { + "role": response.role, + "name": response.name or "*", + "content": response.content, + } + await batch_service.create_file_log( + str(file.file_id), + description, + current_migration, + LogType.ERROR, + AgentType(response.name), + AuthorRole(response.role), ) + current_migration = None + break + case AgentType.SYNTAX_CHECKER.value: + result = SyntaxCheckerResponse.model_validate_json( + response.content.lower() or "" ) - case AgentType.PICKER.value: - result = PickerResponse.model_validate_json( - response.content or "" - ) - current_migration = result.picked_query - case AgentType.FIXER.value: - result = FixerResponse.model_validate_json( - response.content or "" - ) - current_migration = result.fixed_query - case AgentType.SEMANTIC_VERIFIER.value: - logger.info( - "Semantic verifier agent response: %s", response.content - ) - result = SemanticVerifierResponse.model_validate_json( - response.content or "" - ) - - # If the semantic verifier agent returns a difference, we need to report it - if len(result.differences) > 0: - description = { - "role": AuthorRole.ASSISTANT.value, - "name": AgentType.SEMANTIC_VERIFIER.value, - "content": "\n".join(result.differences), - } - logger.info( - "Semantic verification had issues. Pass with warnings." - ) - # send status update to the client of type in progress with agent status - send_status_update( - status=FileProcessUpdate( - file.batch_id, - file.file_id, - ProcessStatus.COMPLETED, - AgentType.SEMANTIC_VERIFIER, - result.summary, - FileResult.WARNING, - ), + # If there are no syntax errors, we can move to the semantic verifier + # We provide both scripts by injecting them into the chat history + if result.syntax_errors == []: + comms_manager.group_chat.history.add_message( + ChatMessageContent( + role=AuthorRole.USER, + name="candidate", + content=( + f"source_script: {source_script}, \n " + + f"migrated_script: {current_migration}" + ), + ) + ) + case AgentType.PICKER.value: + result = PickerResponse.model_validate_json( + response.content or "" ) - await batch_service.create_file_log( - str(file.file_id), - description, - current_migration, - LogType.WARNING, - AgentType.SEMANTIC_VERIFIER, - AuthorRole.ASSISTANT, + current_migration = result.picked_query + case AgentType.FIXER.value: + result = FixerResponse.model_validate_json( + response.content or "" ) - - elif response == "": - # If the semantic verifier agent returns an empty response + current_migration = result.fixed_query + case AgentType.SEMANTIC_VERIFIER.value: logger.info( - "Semantic verification had no return value. Pass with warnings." + "Semantic verifier agent response: %s", response.content ) - # send status update to the client of type in progress with agent status - send_status_update( - status=FileProcessUpdate( - file.batch_id, - file.file_id, - ProcessStatus.COMPLETED, + result = SemanticVerifierResponse.model_validate_json( + response.content or "" + ) + + # If the semantic verifier agent returns a difference, we need to report it + if len(result.differences) > 0: + description = { + "role": AuthorRole.ASSISTANT.value, + "name": AgentType.SEMANTIC_VERIFIER.value, + "content": "\n".join(result.differences), + } + logger.info( + "Semantic verification had issues. Pass with warnings." 
+ ) + # send status update to the client of type in progress with agent status + send_status_update( + status=FileProcessUpdate( + file.batch_id, + file.file_id, + ProcessStatus.COMPLETED, + AgentType.SEMANTIC_VERIFIER, + result.summary, + FileResult.WARNING, + ), + ) + await batch_service.create_file_log( + str(file.file_id), + description, + current_migration, + LogType.WARNING, AgentType.SEMANTIC_VERIFIER, + AuthorRole.ASSISTANT, + ) + + elif response == "": + # If the semantic verifier agent returns an empty response + logger.info( + "Semantic verification had no return value. Pass with warnings." + ) + # send status update to the client of type in progress with agent status + send_status_update( + status=FileProcessUpdate( + file.batch_id, + file.file_id, + ProcessStatus.COMPLETED, + AgentType.SEMANTIC_VERIFIER, + "No return value from semantic verifier agent.", + FileResult.WARNING, + ), + ) + await batch_service.create_file_log( + str(file.file_id), "No return value from semantic verifier agent.", - FileResult.WARNING, - ), - ) - await batch_service.create_file_log( - str(file.file_id), - "No return value from semantic verifier agent.", - current_migration, - LogType.WARNING, - AgentType.SEMANTIC_VERIFIER, - AuthorRole.ASSISTANT, - ) + current_migration, + LogType.WARNING, + AgentType.SEMANTIC_VERIFIER, + AuthorRole.ASSISTANT, + ) - description = { - "role": response.role, - "name": response.name or "*", - "content": response.content, - } + description = { + "role": response.role, + "name": response.name or "*", + "content": response.content, + } - logger.info(description) + logger.info(description) - # send status update to the client of type in progress with agent status + # send status update to the client of type in progress with agent status + # send_status_update( + # status=FileProcessUpdate( + # file.batch_id, + # file.file_id, + # ProcessStatus.IN_PROGRESS, + # AgentType(response.name), + # json.loads(response.content)["summary"], + # FileResult.INFO, + # ), + # ) + # Safely parse response content to avoid crashing on malformed or incomplete JSON + # start + try: + parsed_content = json.loads(response.content or "{}") + except json.JSONDecodeError: + logger.warning("Invalid JSON from agent: %s", response.content) + parsed_content = { + "input_summary": "", + "candidates": [], + "summary": "", + "input_error": "", + "rai_error": "Invalid JSON from agent.", + } + + # Send status update using safe fallback values + send_status_update( + status=FileProcessUpdate( + file.batch_id, + file.file_id, + ProcessStatus.IN_PROGRESS, + AgentType(response.name), + parsed_content.get("summary", ""), + FileResult.INFO, + ), + ) + # end + await batch_service.create_file_log( + str(file.file_id), + description, + current_migration, + LogType.INFO, + AgentType(response.name), + AuthorRole(response.role), + ) + except Exception as e: + # logger.error("Error during chat.invoke(): %s", str(e)) + logger.error("Error during comms_manager.async_invoke(): %s", str(e)) + # Log the error to the batch service for tracking + await batch_service.create_file_log( + str(file.file_id), + f"Critical error during agent communication: {str(e)}", + current_migration, + LogType.ERROR, + AgentType.ALL, + AuthorRole.ASSISTANT, + ) + # Send error status update send_status_update( status=FileProcessUpdate( file.batch_id, file.file_id, - ProcessStatus.IN_PROGRESS, - AgentType(response.name), - json.loads(response.content)["summary"], - FileResult.INFO, + ProcessStatus.COMPLETED, + AgentType.ALL, + f"Processing failed: 
{str(e)}", + FileResult.ERROR, ), ) + break # Exit the while loop on critical error - await batch_service.create_file_log( - str(file.file_id), - description, - current_migration, - LogType.INFO, - AgentType(response.name), - AuthorRole(response.role), - ) - - if chat.is_complete: + if comms_manager.group_chat.is_complete: is_complete = True break diff --git a/src/backend/sql_agents/helpers/comms_manager.py b/src/backend/sql_agents/helpers/comms_manager.py index d465ef0..77580c3 100644 --- a/src/backend/sql_agents/helpers/comms_manager.py +++ b/src/backend/sql_agents/helpers/comms_manager.py @@ -1,10 +1,18 @@ """Manages all agent communication and chat strategies for the SQL agents.""" +import asyncio +import copy +import logging +import re +from typing import AsyncIterable, ClassVar + from semantic_kernel.agents import AgentGroupChat # pylint: disable=E0611 from semantic_kernel.agents.strategies import ( SequentialSelectionStrategy, TerminationStrategy, ) +from semantic_kernel.contents import ChatMessageContent +from semantic_kernel.exceptions import AgentInvokeException from sql_agents.agents.migrator.response import MigratorResponse from sql_agents.helpers.models import AgentType @@ -13,6 +21,12 @@ class CommsManager: """Manages all agent communication and selection strategies for the SQL agents.""" + # Class level logger + logger: ClassVar[logging.Logger] = logging.getLogger(__name__) + + # regex to extract the recommended wait time in seconds from response + _EXTRACT_WAIT_TIME = r"in (\d+) seconds" + group_chat: AgentGroupChat = None class SelectionStrategy(SequentialSelectionStrategy): @@ -100,8 +114,33 @@ async def should_agent_terminate(self, agent, history): return terminate - def __init__(self, agent_dict): - """Initialize the CommsManager and agent_chat with the given agents.""" + def __init__( + self, + agent_dict, + exception_types: tuple = (Exception,), + max_retries: int = 10, + initial_delay: float = 1.0, + backoff_factor: float = 2.0, + simple_truncation: int = None, + ): + """Initialize the CommsManager and agent_chat with the given agents. 
+ + Args: + agent_dict: Dictionary of agents + exception_types: Tuple of exception types that should trigger a retry + max_retries: Maximum number of retry attempts (default: 10) + initial_delay: Initial delay in seconds before first retry (default: 1.0) + backoff_factor: Factor by which the delay increases with each retry (default: 2.0) + simple_truncation: Optional truncation limit for chat history + """ + # Store retry configuration + self.max_retries = max_retries + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor + self.exception_types = exception_types + self.simple_truncation = simple_truncation + + # Initialize the group chat (maintaining original functionality) self.group_chat = AgentGroupChat( agents=agent_dict.values(), termination_strategy=self.ApprovalTerminationStrategy( @@ -114,3 +153,108 @@ def __init__(self, agent_dict): ), selection_strategy=self.SelectionStrategy(agents=agent_dict.values()), ) + + async def invoke_async(self): + """Invoke the group chat with the given agents (original method maintained for compatibility).""" + return self.group_chat.invoke() + + async def async_invoke(self) -> AsyncIterable[ChatMessageContent]: + """Invoke the group chat with retry logic and error handling.""" + attempt = 0 + current_delay = self.initial_delay + + while attempt < self.max_retries: + try: + # Grab a snapshot of the history of the group chat + # Using "SHALLOW" copy to avoid getting a reference to the original list + history_snap = copy.copy(self.group_chat.history) + + self.logger.debug( + "History before invoke: %s", + [msg.name for msg in self.group_chat.history], + ) + + # Get a fresh iterator from the function + async_iter = self.group_chat.invoke() + + # If simple truncation is set, truncate the history + if ( + self.simple_truncation + and len(self.group_chat.history) > self.simple_truncation + ): + # Truncate the history to the last n messages + self.group_chat.history = history_snap[-self.simple_truncation :] + + # Yield each item from the iterator + async for item in async_iter: + yield item + + # If we get here without exception, we're done + break + + except AgentInvokeException as aie: + attempt += 1 + if attempt >= self.max_retries: + self.logger.error( + "Function invoke failed after %d attempts. Final error: %s. Consider increasing the models rate limit.", + self.max_retries, + str(aie), + ) + # Re-raise the last exception if all retries failed + raise + + # Return history state for retry + self.group_chat.history = history_snap + + try: + # Try to extract wait time from error message + wait_time_match = re.search(self._EXTRACT_WAIT_TIME, str(aie)) + if wait_time_match: + # If regex is found, set the delay to the value in seconds + current_delay = int(wait_time_match.group(1)) + else: + current_delay = self.initial_delay + + self.logger.warning( + "Attempt %d/%d for function invoke failed: %s. Retrying in %.2f seconds...", + attempt, + self.max_retries, + str(aie), + current_delay, + ) + + # Wait before retrying + await asyncio.sleep(current_delay) + + if not wait_time_match: + # Increase delay for next attempt using backoff factor + current_delay *= self.backoff_factor + + except Exception as ex: + self.logger.error( + "Retry error: %s. Using default delay.", + ex, + ) + current_delay = self.initial_delay + + except self.exception_types as e: + attempt += 1 + if attempt >= self.max_retries: + self.logger.error( + "Function invoke failed after %d attempts. 
Final error: %s", + self.max_retries, + str(e), + ) + raise + + self.logger.warning( + "Attempt %d/%d failed with %s: %s. Retrying in %.2f seconds...", + attempt, + self.max_retries, + type(e).__name__, + str(e), + current_delay, + ) + + await asyncio.sleep(current_delay) + current_delay *= self.backoff_factor diff --git a/src/backend/sql_agents/process_batch.py b/src/backend/sql_agents/process_batch.py index cb22efe..cd2452f 100644 --- a/src/backend/sql_agents/process_batch.py +++ b/src/backend/sql_agents/process_batch.py @@ -4,7 +4,6 @@ It is the main entry point for the SQL migration process. """ -import asyncio import logging from api.status_updates import send_status_update @@ -132,7 +131,7 @@ async def process_batch_async( else: await batch_service.update_file_counts(file["file_id"]) # TEMPORARY: awaiting bug fix for rate limits - await asyncio.sleep(5) + # await asyncio.sleep(5) except UnicodeDecodeError as ucde: logger.error("Error decoding file: %s", file) logger.error("Error decoding file. %s", ucde)