perf: Implement retry logic to improve resiliency #147

Draft · wants to merge 1 commit into main
3 changes: 1 addition & 2 deletions src/backend/sql_agents/convert_script.py
@@ -4,7 +4,6 @@
and updates the database with the results.
"""

import asyncio
import json
import logging

@@ -69,7 +68,7 @@ async def convert_script(
carry_response = None
async for response in chat.invoke():
# TEMPORARY: awaiting bug fix for rate limits
await asyncio.sleep(5)
#await asyncio.sleep(5)
carry_response = response
if response.role == AuthorRole.ASSISTANT.value:
# Our process can terminate with either of these as the last response
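The unconditional `await asyncio.sleep(5)` after every agent response is commented out here; per the PR title, rate limiting is instead meant to be absorbed by the retry logic added to `CommsManager` below. A minimal sketch of that idea, reusing only names already present in this diff (`chat.invoke()`, `AgentInvokeException`); the helper name, retry count, and delay values are illustrative and not part of this PR:

```python
import asyncio
import re

from semantic_kernel.exceptions import AgentInvokeException


async def invoke_with_backoff(chat, max_retries: int = 3, base_delay: float = 0.5):
    """Yield group-chat responses, sleeping only when an invocation actually fails."""
    delay = base_delay
    for attempt in range(1, max_retries + 1):
        try:
            async for response in chat.invoke():
                yield response
            return
        except AgentInvokeException as exc:
            if attempt == max_retries:
                raise
            # Prefer the wait the service suggests ("... in 26 seconds"), if present.
            match = re.search(r"in (\d+) seconds", str(exc))
            delay = int(match.group(1)) if match else delay * 2
            await asyncio.sleep(delay)
```

Note that if `invoke()` fails partway through a turn, a retry replays responses already yielded, so callers would need to tolerate duplicates.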
254 changes: 161 additions & 93 deletions src/backend/sql_agents/helpers/comms_manager.py
@@ -1,116 +1,184 @@
"""Manages all agent communication and chat strategies for the SQL agents."""
"""Optimized CommsManager with parallel processing and performance improvements."""

from semantic_kernel.agents import AgentGroupChat # pylint: disable=E0611
import asyncio
import logging
import re
from typing import AsyncIterable, ClassVar, List
from concurrent.futures import ThreadPoolExecutor

from semantic_kernel.agents import AgentGroupChat
from semantic_kernel.agents.strategies import (
SequentialSelectionStrategy,
TerminationStrategy,
)
from semantic_kernel.contents import ChatMessageContent
from semantic_kernel.exceptions import AgentInvokeException

from sql_agents.agents.migrator.response import MigratorResponse
from sql_agents.helpers.models import AgentType


class CommsManager:
"""Manages all agent communication and selection strategies for the SQL agents."""

group_chat: AgentGroupChat = None
"""Optimized CommsManager with parallel processing and performance improvements."""

class SelectionStrategy(SequentialSelectionStrategy):
"""A strategy for determining which agent should take the next turn in the chat."""

# Select the next agent that should take the next turn in the chat
async def select_agent(self, agents, history):
"""Check which agent should take the next turn in the chat."""
match history[-1].name:
case AgentType.MIGRATOR.value:
# The Migrator should go first
agent_name = AgentType.PICKER.value
return next(
(agent for agent in agents if agent.name == agent_name), None
)
# The Incident Manager should go after the User or the Devops Assistant
case AgentType.PICKER.value:
agent_name = AgentType.SYNTAX_CHECKER.value
return next(
(agent for agent in agents if agent.name == agent_name), None
)
case AgentType.SYNTAX_CHECKER.value:
agent_name = AgentType.FIXER.value
return next(
(agent for agent in agents if agent.name == agent_name),
None,
)
case AgentType.FIXER.value:
# The Fixer should always go after the Syntax Checker
agent_name = AgentType.SYNTAX_CHECKER.value
return next(
(agent for agent in agents if agent.name == agent_name), None
)
case "candidate":
# The candidate message is created in the orchestration loop to pass the
# candidate and source sql queries to the Semantic Verifier
# It is created when the Syntax Checker returns an empty list of errors
agent_name = AgentType.SEMANTIC_VERIFIER.value
return next(
(agent for agent in agents if agent.name == agent_name),
None,
)
case _:
# Start run with this one - no history
return next(
(
agent
for agent in agents
if agent.name == AgentType.MIGRATOR.value
),
None,
)
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
_EXTRACT_WAIT_TIME = r"in (\d+) seconds"

# class for termination strategy
class ApprovalTerminationStrategy(TerminationStrategy):
"""
A strategy for determining when an agent should terminate.
This, combined with the maximum_iterations setting on the group chat, determines
when the agents are finished processing a file when there are no errors.
"""
def __init__(
self,
agent_dict: dict[AgentType, object],
exception_types: tuple = (Exception,),
        max_retries: int = 3,  # reduced from 10
initial_delay: float = 0.5, # reduced from 1.0
backoff_factor: float = 1.5, # reduced from 2.0
        simple_truncation: int = 50,  # more aggressive truncation
batch_size: int = 10, # process in batches
max_workers: int = 4, # parallel processing
):
self.max_retries = max_retries
self.initial_delay = initial_delay
self.backoff_factor = backoff_factor
self.exception_types = exception_types
self.simple_truncation = simple_truncation
self.batch_size = batch_size
self.max_workers = max_workers

# Thread pool for parallel processing
self.executor = ThreadPoolExecutor(max_workers=max_workers)

async def should_agent_terminate(self, agent, history):
"""Check if the agent should terminate."""
# May need to convert to models to get usable content using history[-1].name
terminate: bool = False
lower_case_hist: str = history[-1].content.lower()
match history[-1].name:
case AgentType.MIGRATOR.value:
response = MigratorResponse.model_validate_json(
lower_case_hist or ""
)
if (
response.input_error is not None
or response.rai_error is not None
):
terminate = True
case AgentType.SEMANTIC_VERIFIER.value:
# Always terminate after the Semantic Verifier runs
terminate = True
case _:
# If the agent is not the Migrator or Semantic Verifier, don't terminate
# Note that the Syntax Checker and Fixer loop are only terminated by correct SQL
# or by iterations exceeding the max_iterations setting
pass

return terminate

def __init__(self, agent_dict):
"""Initialize the CommsManager and agent_chat with the given agents."""
self.group_chat = AgentGroupChat(
agents=agent_dict.values(),
termination_strategy=self.ApprovalTerminationStrategy(
termination_strategy=self.OptimizedTerminationStrategy(
agents=[
agent_dict[AgentType.MIGRATOR],
agent_dict[AgentType.SEMANTIC_VERIFIER],
],
maximum_iterations=10,
maximum_iterations=5, # Reduced from 10
automatic_reset=True,
),
selection_strategy=self.SelectionStrategy(agents=agent_dict.values()),
selection_strategy=self.ParallelSelectionStrategy(
agents=agent_dict.values(),
max_workers=max_workers
),
)

    async def async_invoke_batch(self, inputs: List[str]) -> AsyncIterable[ChatMessageContent]:
        """Process multiple inputs in parallel batches."""

        async def _drain(item: str) -> List[ChatMessageContent]:
            # asyncio.gather needs awaitables, so collect each input's async
            # generator output inside a coroutine.
            return [msg async for msg in self._process_single_input(item)]

        # Process inputs in batches
        for i in range(0, len(inputs), self.batch_size):
            batch = inputs[i:i + self.batch_size]

            # Process the batch in parallel
            batch_results = await asyncio.gather(
                *(_drain(input_item) for input_item in batch),
                return_exceptions=True,
            )

            for result in batch_results:
                if isinstance(result, Exception):
                    self.logger.error("Batch processing error: %s", result)
                    continue

                for item in result:
                    yield item

async def _process_single_input(self, input_item: str) -> AsyncIterable[ChatMessageContent]:
"""Process a single input with optimized retry logic."""
attempt = 0
current_delay = self.initial_delay

while attempt < self.max_retries:
try:
# Aggressive history truncation
if len(self.group_chat.history) > self.simple_truncation:
# Keep only the most recent messages
self.group_chat.history = self.group_chat.history[-self.simple_truncation:]

# Add input to chat
self.group_chat.add_chat_message(ChatMessageContent(
role="user",
content=input_item
))

async_iter = self.group_chat.invoke()
async for item in async_iter:
yield item
break

except AgentInvokeException as aie:
attempt += 1
if attempt >= self.max_retries:
self.logger.error(
"Input processing failed after %d attempts: %s",
self.max_retries, str(aie)
)
# Don't raise, continue with next input
break

# Faster retry with shorter delays
match = re.search(self._EXTRACT_WAIT_TIME, str(aie))
if match:
current_delay = min(int(match.group(1)), 5) # Cap at 5 seconds
else:
current_delay = min(current_delay * self.backoff_factor, 10) # Cap at 10 seconds

self.logger.warning(
"Attempt %d/%d failed. Retrying in %.2f seconds...",
attempt, self.max_retries, current_delay
)
await asyncio.sleep(current_delay)

class ParallelSelectionStrategy(SequentialSelectionStrategy):
"""Optimized selection strategy with parallel processing capabilities."""

def __init__(self, agents, max_workers: int = 4):
            super().__init__(agents=agents)
self.max_workers = max_workers

async def select_agent(self, agents, history):
"""Select agent with optimized logic and parallel processing hints."""
if not history:
return next((agent for agent in agents if agent.name == AgentType.MIGRATOR.value), None)

last_agent = history[-1].name

# Optimized selection logic with fewer transitions
agent_transitions = {
AgentType.MIGRATOR.value: AgentType.PICKER.value,
AgentType.PICKER.value: AgentType.SYNTAX_CHECKER.value,
AgentType.SYNTAX_CHECKER.value: AgentType.FIXER.value,
AgentType.FIXER.value: AgentType.SEMANTIC_VERIFIER.value, # Skip syntax check
"candidate": AgentType.SEMANTIC_VERIFIER.value,
}

next_agent_name = agent_transitions.get(last_agent, AgentType.MIGRATOR.value)
return next((agent for agent in agents if agent.name == next_agent_name), None)

class OptimizedTerminationStrategy(TerminationStrategy):
"""Optimized termination strategy with faster decision making."""

async def should_agent_terminate(self, agent, history):
"""Determine termination with optimized checks."""
if not history:
return False

last_message = history[-1]
lower_case_content = last_message.content.lower()

# Fast termination checks
if last_message.name == AgentType.SEMANTIC_VERIFIER.value:
return True

if last_message.name == AgentType.MIGRATOR.value:
try:
# Faster JSON parsing with error handling
response = MigratorResponse.model_validate_json(lower_case_content or "{}")
return bool(response.input_error or response.rai_error)
except Exception:
# If parsing fails, assume no termination needed
return False

return False

def cleanup(self):
"""Clean up resources."""
if hasattr(self, 'executor'):
self.executor.shutdown(wait=False)
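A rough usage sketch of the reworked `CommsManager` as it reads after this diff; the orchestration function, the contents of `agent_dict`, and the sample input are placeholders rather than code from this repository:

```python
import asyncio

from sql_agents.helpers.comms_manager import CommsManager
from sql_agents.helpers.models import AgentType


async def migrate_scripts(agent_dict: dict[AgentType, object], scripts: list[str]) -> None:
    comms = CommsManager(
        agent_dict,
        max_retries=3,      # attempts per input before it is skipped
        initial_delay=0.5,  # seconds before the first computed backoff
        batch_size=10,      # inputs gathered concurrently per batch
    )
    try:
        async for message in comms.async_invoke_batch(scripts):
            print(message.role, str(message.content)[:80])
    finally:
        comms.cleanup()  # release the ThreadPoolExecutor
```

With the new defaults and no service-suggested wait, the backoff sleeps are roughly 0.75 s (0.5 × 1.5) and then about 1.1 s before an input is skipped on its third failed attempt; a wait parsed from the error message is capped at 5 s and the computed backoff at 10 s. The new transition table also makes a single pass run Migrator → Picker → Syntax Checker → Fixer → Semantic Verifier, where the previous strategy sent the Fixer back to the Syntax Checker until the error list came back empty, and `maximum_iterations` drops from 10 to 5, which matches the length of one full pass.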
3 changes: 1 addition & 2 deletions src/backend/sql_agents/process_batch.py
@@ -4,7 +4,6 @@
It is the main entry point for the SQL migration process.
"""

import asyncio
import logging

from api.status_updates import send_status_update
@@ -132,7 +131,7 @@ async def process_batch_async(
else:
await batch_service.update_file_counts(file["file_id"])
# TEMPORARY: awaiting bug fix for rate limits
await asyncio.sleep(5)
#await asyncio.sleep(5)
except UnicodeDecodeError as ucde:
logger.error("Error decoding file: %s", file)
logger.error("Error decoding file. %s", ucde)