diff --git a/src/api/session_routes.py b/src/api/session_routes.py index 500993ae..20c62271 100644 --- a/src/api/session_routes.py +++ b/src/api/session_routes.py @@ -158,7 +158,7 @@ async def get_agent_messages( processed_events = [] for event in events: - event_dict = event.dict() + event_dict = event.model_dump() def process_dict(d): if isinstance(d, dict): diff --git a/src/schemas/schemas.py b/src/schemas/schemas.py index 5a3945c7..3649bdd7 100644 --- a/src/schemas/schemas.py +++ b/src/schemas/schemas.py @@ -27,7 +27,7 @@ └──────────────────────────────────────────────────────────────────────────────┘ """ -from pydantic import BaseModel, Field, validator, UUID4, ConfigDict +from pydantic import BaseModel, Field, field_validator, model_validator, UUID4, ConfigDict from typing import Optional, Dict, Any, List from datetime import datetime from uuid import UUID @@ -40,7 +40,8 @@ class ClientBase(BaseModel): name: str email: Optional[str] = None - @validator("email") + @field_validator("email") + @classmethod def validate_email(cls, v): if v is None: return v @@ -58,8 +59,7 @@ class Client(ClientBase): id: UUID created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ApiKeyBase(BaseModel): @@ -101,7 +101,7 @@ class AgentBase(BaseModel): description="Agent type (llm, sequential, parallel, loop, a2a, workflow, task)", ) model: Optional[str] = Field( - None, description="Agent model (required only for llm type)" + None, description="LLM model identifier (required for LLM agents only)" ) api_key_id: Optional[UUID4] = Field( None, description="Reference to a stored API Key ID" @@ -115,8 +115,13 @@ class AgentBase(BaseModel): ) config: Any = Field(None, description="Agent configuration based on type") - @validator("name") - def validate_name(cls, v, values): + @field_validator("name") + @classmethod + def validate_name(cls, v, info): + # Get values from validation context + values = info.data if hasattr(info, 'data') else {} + + # A2A agents can have optional names if values.get("type") == "a2a": return v @@ -127,107 +132,246 @@ def validate_name(cls, v, values): raise ValueError("Agent name cannot contain spaces or special characters") return v - @validator("type") + @field_validator("type") + @classmethod def validate_type(cls, v): - if v not in [ + valid_types = [ "llm", - "sequential", - "parallel", - "loop", - "a2a", - "workflow", - "task", - ]: + "sequential", + "parallel", + "loop", + "a2a", + "workflow", + "task" + ] + if v not in valid_types: raise ValueError( - "Invalid agent type. Must be: llm, sequential, parallel, loop, a2a, workflow or task" + f"Invalid agent type '{v}'. Must be one of: {', '.join(valid_types)}" ) return v - @validator("agent_card_url") - def validate_agent_card_url(cls, v, values): - if "type" in values and values["type"] == "a2a": + @field_validator("agent_card_url") + @classmethod + def validate_agent_card_url(cls, v, info): + values = info.data if hasattr(info, 'data') else {} + + if values.get("type") == "a2a": if not v: raise ValueError("agent_card_url is required for a2a type agents") if not v.endswith("/.well-known/agent.json"): raise ValueError("agent_card_url must end with /.well-known/agent.json") return v - @validator("model") - def validate_model(cls, v, values): - if "type" in values and values["type"] == "llm" and not v: - raise ValueError("Model is required for llm type agents") + @field_validator("model") + @classmethod + def validate_model(cls, v, info): + values = info.data if hasattr(info, 'data') else {} + agent_type = values.get("type") + + if agent_type == "llm": + # Para agentes LLM, o modelo é obrigatório e não pode ser vazio + if not v or (isinstance(v, str) and v.strip() == ""): + raise ValueError( + "LLM agents require a valid model configuration. " + "Please specify a model identifier (e.g., 'gpt-4', 'claude-3-sonnet', 'gemini-pro')" + ) + + # Verificar se o modelo tem um formato válido + if isinstance(v, str) and len(v.strip()) < 3: + raise ValueError("Model identifier must be at least 3 characters long") + + elif agent_type in ["workflow", "task", "sequential", "parallel", "loop"]: + # Para estes tipos, não devem ter modelo + if v and (isinstance(v, str) and v.strip()): + # Avisar mas permitir (será removido durante a criação) + import logging + logger = logging.getLogger(__name__) + logger.warning(f"{agent_type} agents don't need model configuration. Model will be ignored.") + return v - @validator("api_key_id") - def validate_api_key_id(cls, v, values): + @field_validator("api_key_id") + @classmethod + def validate_api_key_id(cls, v, info): + values = info.data if hasattr(info, 'data') else {} + agent_type = values.get("type") + + # API key é obrigatório para agentes LLM (a menos que esteja na config) + if agent_type == "llm" and not v: + # Verificar se tem API key na config + config = values.get("config", {}) + if not config or not config.get("api_key"): + # Não falhar aqui, deixar a validação para o momento da criação + pass + return v - @validator("config") - def validate_config(cls, v, values): - if "type" in values and values["type"] == "a2a": + @field_validator("config") + @classmethod + def validate_config(cls, v, info): + values = info.data if hasattr(info, 'data') else {} + agent_type = values.get("type") + + if not agent_type: + return v + + # A2A agents têm config opcional + if agent_type == "a2a": return v or {} - if "type" not in values: + # Workflow agents têm config específico para workflow + if agent_type == "workflow": + if v and isinstance(v, dict): + if not v.get("workflow"): + raise ValueError("Workflow agents must have 'workflow' configuration") return v - # For workflow agents, we do not perform any validation - if "type" in values and values["type"] == "workflow": - return v + # Config é obrigatório para outros tipos (exceto a2a) + if not v and agent_type not in ["a2a"]: + raise ValueError( + f"Configuration is required for {agent_type} agent type" + ) + + # Validação específica por tipo + if agent_type == "llm": + return cls._validate_llm_config(v) + elif agent_type in ["sequential", "parallel", "loop"]: + return cls._validate_composite_config(v, agent_type) + elif agent_type == "task": + return cls._validate_task_config(v) - if not v and values.get("type") != "a2a": + return v + + @classmethod + def _validate_llm_config(cls, v): + """Valida configuração para agentes LLM""" + if isinstance(v, dict): + try: + # Convert the dictionary to LLMConfig + v = LLMConfig(**v) + except Exception as e: + raise ValueError(f"Invalid LLM configuration: {str(e)}") + elif not isinstance(v, LLMConfig): + raise ValueError("Invalid LLM configuration format") + return v + + @classmethod + def _validate_composite_config(cls, v, agent_type): + """Valida configuração para agentes compostos (sequential, parallel, loop)""" + if not isinstance(v, dict): + raise ValueError(f'Configuration for {agent_type} agent must be a dictionary') + + if "sub_agents" not in v: + raise ValueError(f'{agent_type} agents must have sub_agents configuration') + + if not isinstance(v["sub_agents"], list): + raise ValueError("sub_agents must be a list") + + if not v["sub_agents"]: raise ValueError( - f"Configuration is required for {values.get('type')} agent type" + f'{agent_type} agents must have at least one sub-agent' ) + + # Validação específica para LoopAgent + if agent_type == "loop": + max_iterations = v.get("max_iterations", 5) + if not isinstance(max_iterations, int) or max_iterations <= 0: + raise ValueError("max_iterations must be a positive integer") + + return v - if values["type"] == "llm": - if isinstance(v, dict): - try: - # Convert the dictionary to LLMConfig - v = LLMConfig(**v) - except Exception as e: - raise ValueError(f"Invalid LLM configuration for agent: {str(e)}") - elif not isinstance(v, LLMConfig): - raise ValueError("Invalid LLM configuration for agent") - elif values["type"] in ["sequential", "parallel", "loop"]: - if not isinstance(v, dict): - raise ValueError(f'Invalid configuration for agent {values["type"]}') - if "sub_agents" not in v: - raise ValueError(f'Agent {values["type"]} must have sub_agents') + @classmethod + def _validate_task_config(cls, v): + """Valida configuração para agentes de task""" + if not isinstance(v, dict): + raise ValueError('Configuration for task agent must be a dictionary') + + if "tasks" not in v: + raise ValueError('Task agents must have tasks configuration') + + if not isinstance(v["tasks"], list): + raise ValueError("tasks must be a list") + + if not v["tasks"]: + raise ValueError('Task agents must have at least one task') + + # Validar cada task individualmente + for i, task in enumerate(v["tasks"]): + if not isinstance(task, dict): + raise ValueError(f"Task {i+1} must be a dictionary") + + required_fields = ["agent_id", "description", "expected_output"] + for field in required_fields: + if field not in task: + raise ValueError(f"Task {i+1} missing required field: {field}") + + # Verificar se os campos não estão vazios + if not task[field] or (isinstance(task[field], str) and not task[field].strip()): + raise ValueError(f"Task {i+1} field '{field}' cannot be empty") + + # Validar sub_agents se presente + if "sub_agents" in v and v["sub_agents"] is not None: if not isinstance(v["sub_agents"], list): raise ValueError("sub_agents must be a list") - if not v["sub_agents"]: - raise ValueError( - f'Agent {values["type"]} must have at least one sub-agent' - ) - elif values["type"] == "task": - if not isinstance(v, dict): - raise ValueError(f'Invalid configuration for agent {values["type"]}') - if "tasks" not in v: - raise ValueError(f'Agent {values["type"]} must have tasks') - if not isinstance(v["tasks"], list): - raise ValueError("tasks must be a list") - if not v["tasks"]: - raise ValueError(f'Agent {values["type"]} must have at least one task') - for task in v["tasks"]: - if not isinstance(task, dict): - raise ValueError("Each task must be a dictionary") - required_fields = ["agent_id", "description", "expected_output"] - for field in required_fields: - if field not in task: - raise ValueError(f"Task missing required field: {field}") - - if "sub_agents" in v and v["sub_agents"] is not None: - if not isinstance(v["sub_agents"], list): - raise ValueError("sub_agents must be a list") - - return v return v + @model_validator(mode='after') + def validate_agent_consistency(self): + """Validação cruzada entre campos do agente""" + + # Verificar consistência entre tipo e configurações + if self.type == "llm": + # LLM agents devem ter modelo + if not self.model or (isinstance(self.model, str) and self.model.strip() == ""): + raise ValueError("LLM agents must have a valid model") + + # LLM agents devem ter API key (na config ou api_key_id) + has_api_key = bool(self.api_key_id) + if not has_api_key and self.config: + config_dict = self.config if isinstance(self.config, dict) else self.config.__dict__ + has_api_key = bool(config_dict.get("api_key")) + + if not has_api_key: + raise ValueError("LLM agents must have an API key configured") + + elif self.type in ["workflow", "task", "sequential", "parallel", "loop"]: + # Orchestrator agents não devem ter modelo + if self.model and isinstance(self.model, str) and self.model.strip(): + import logging + logger = logging.getLogger(__name__) + logger.warning(f"{self.type} agents don't need model configuration. Clearing model.") + self.model = None + + elif self.type == "a2a": + # A2A agents devem ter agent_card_url + if not self.agent_card_url: + raise ValueError("A2A agents must have agent_card_url") + + return self + class AgentCreate(AgentBase): client_id: UUID + @model_validator(mode='after') + def validate_creation_requirements(self): + """Validações específicas para criação de agentes""" + + # Chamar validação da classe pai + super().validate_agent_consistency() + + # Validações específicas para criação + if self.type == "llm": + # Para criação, ser mais rigoroso com modelo + if not self.model or len(self.model.strip()) < 3: + raise ValueError( + "LLM agents require a valid model identifier (minimum 3 characters). " + "Examples: 'gpt-4', 'claude-3-sonnet', 'gemini-pro'" + ) + + return self + class Agent(AgentBase): id: UUID @@ -237,17 +381,17 @@ class Agent(AgentBase): agent_card_url: Optional[str] = None folder_id: Optional[UUID4] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) - @validator("agent_card_url", pre=True) - def set_agent_card_url(cls, v, values): + @field_validator("agent_card_url", mode='before') + @classmethod + def set_agent_card_url(cls, v, info): if v: return v + values = info.data if hasattr(info, 'data') else {} if "id" in values: from os import getenv - return f"{getenv('API_URL', '')}/api/v1/a2a/{values['id']}/.well-known/agent.json" return v @@ -262,6 +406,7 @@ class ToolConfig(BaseModel): inputModes: List[str] = Field(default_factory=list) outputModes: List[str] = Field(default_factory=list) + # Last edited by Arley Peter on 2025-05-17 class MCPServerBase(BaseModel): name: str @@ -272,6 +417,29 @@ class MCPServerBase(BaseModel): tools: Optional[List[ToolConfig]] = Field(default_factory=list) type: str = Field(default="official") + @field_validator("name") + @classmethod + def validate_name(cls, v): + if not v or not v.strip(): + raise ValueError("MCP Server name cannot be empty") + return v.strip() + + @field_validator("config_type") + @classmethod + def validate_config_type(cls, v): + valid_types = ["studio", "sse"] + if v not in valid_types: + raise ValueError(f"config_type must be one of: {valid_types}") + return v + + @field_validator("type") + @classmethod + def validate_type(cls, v): + valid_types = ["official", "community"] + if v not in valid_types: + raise ValueError(f"type must be one of: {valid_types}") + return v + class MCPServerCreate(MCPServerBase): pass @@ -282,8 +450,7 @@ class MCPServer(MCPServerBase): created_at: datetime updated_at: Optional[datetime] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ToolBase(BaseModel): @@ -292,6 +459,13 @@ class ToolBase(BaseModel): config_json: Dict[str, Any] = Field(default_factory=dict) environments: Dict[str, Any] = Field(default_factory=dict) + @field_validator("name") + @classmethod + def validate_name(cls, v): + if not v or not v.strip(): + raise ValueError("Tool name cannot be empty") + return v.strip() + class ToolCreate(ToolBase): pass @@ -302,14 +476,20 @@ class Tool(ToolBase): created_at: datetime updated_at: Optional[datetime] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class AgentFolderBase(BaseModel): name: str description: Optional[str] = None + @field_validator("name") + @classmethod + def validate_name(cls, v): + if not v or not v.strip(): + raise ValueError("Folder name cannot be empty") + return v.strip() + class AgentFolderCreate(AgentFolderBase): client_id: UUID4 @@ -326,3 +506,86 @@ class AgentFolder(AgentFolderBase): updated_at: Optional[datetime] = None model_config = ConfigDict(from_attributes=True) + + +class AgentTypeInfo(BaseModel): + """Informações sobre tipos de agente válidos""" + type: str + requires_model: bool + requires_config: bool + description: str + + @classmethod + def get_valid_types(cls) -> Dict[str, 'AgentTypeInfo']: + """Retorna informações sobre todos os tipos válidos de agente""" + return { + "llm": cls( + type="llm", + requires_model=True, + requires_config=True, + description="Large Language Model agent - requires model and API key" + ), + "workflow": cls( + type="workflow", + requires_model=False, + requires_config=True, + description="Workflow orchestrator agent - uses LangGraph for complex flows" + ), + "task": cls( + type="task", + requires_model=False, + requires_config=True, + description="Task management agent - coordinates multiple tasks" + ), + "sequential": cls( + type="sequential", + requires_model=False, + requires_config=True, + description="Sequential execution agent - runs sub-agents in order" + ), + "parallel": cls( + type="parallel", + requires_model=False, + requires_config=True, + description="Parallel execution agent - runs sub-agents concurrently" + ), + "loop": cls( + type="loop", + requires_model=False, + requires_config=True, + description="Loop execution agent - repeats sub-agents with conditions" + ), + "a2a": cls( + type="a2a", + requires_model=False, + requires_config=False, + description="Agent-to-Agent communication - external agent integration" + ) + } + + +class ModelValidationResult(BaseModel): + """Resultado da validação de modelo""" + is_valid: bool + error_message: Optional[str] = None + warnings: List[str] = Field(default_factory=list) + + @classmethod + def success(cls, warnings: List[str] = None) -> 'ModelValidationResult': + return cls(is_valid=True, warnings=warnings or []) + + @classmethod + def failure(cls, error_message: str) -> 'ModelValidationResult': + return cls(is_valid=False, error_message=error_message) + + +class AgentValidationSummary(BaseModel): + """Resumo de validação de agente""" + agent_id: Optional[UUID] = None + agent_name: Optional[str] = None + agent_type: str + is_valid: bool + model_validation: ModelValidationResult + config_validation: ModelValidationResult + general_errors: List[str] = Field(default_factory=list) + warnings: List[str] = Field(default_factory=list) \ No newline at end of file diff --git a/src/services/adk/agent_builder.py b/src/services/adk/agent_builder.py index c461bafd..02a359e5 100644 --- a/src/services/adk/agent_builder.py +++ b/src/services/adk/agent_builder.py @@ -67,15 +67,39 @@ async def _agent_tools_builder(self, agent) -> List[AgentTool]: if agent_tools_ids and isinstance(agent_tools_ids, list): for agent_tool_id in agent_tools_ids: sub_agent = get_agent(self.db, agent_tool_id) - llm_agent, _ = await self.build_llm_agent(sub_agent) - if llm_agent: - agent_tools.append(AgentTool(agent=llm_agent)) + if sub_agent: + # Verificar se o sub_agent é do tipo LLM antes de criar LlmAgent + if sub_agent.type == "llm": + llm_agent, _ = await self.build_llm_agent(sub_agent) + if llm_agent: + agent_tools.append(AgentTool(agent=llm_agent)) + else: + logger.warning(f"Agent tool {agent_tool_id} is not of type 'llm', skipping") + else: + logger.warning(f"Agent tool {agent_tool_id} not found") return agent_tools + def _validate_llm_agent_model(self, agent) -> None: + """Validate that LLM agent has a proper model configuration.""" + if not hasattr(agent, 'model') or not agent.model: + logger.error(f"LLM agent {agent.name} does not have a model configured") + raise ValueError(f"LLM agent {agent.name} requires a model configuration") + + if isinstance(agent.model, str) and agent.model.strip() == "": + logger.error(f"LLM agent {agent.name} has an empty model string") + raise ValueError(f"LLM agent {agent.name} has an empty model configuration") + + logger.info(f"Model validation passed for agent {agent.name}: {agent.model}") + async def _create_llm_agent( self, agent, enabled_tools: List[str] = [] ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]: """Create an LLM agent from the agent data.""" + + self._validate_llm_agent_model(agent) + + logger.info(f"Creating LLM agent: {agent.name} with model: {agent.model}") + # Get custom tools from the configuration custom_tools = [] custom_tools = self.custom_tool_builder.build_tools(agent.config) @@ -110,7 +134,7 @@ async def _create_llm_agent( current_day_of_week=current_day_of_week, current_date_iso=current_date_iso, current_time=current_time, - ) + ) if agent.instruction else "" # add role on beginning of the prompt if agent.role: @@ -170,21 +194,27 @@ async def _create_llm_agent( f"Agent {agent.name} does not have a configured API key" ) - return ( - LlmAgent( + if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""): + raise ValueError(f"Cannot create LiteLlm with empty model for agent {agent.name}") + + try: + llm_agent = LlmAgent( name=agent.name, model=LiteLlm(model=agent.model, api_key=api_key), instruction=formatted_prompt, description=agent.description, tools=all_tools, - ), - mcp_exit_stack, - ) + ) + logger.info(f"LLM agent created successfully: {agent.name}") + return llm_agent, mcp_exit_stack + except Exception as e: + logger.error(f"Error creating LLM agent {agent.name}: {str(e)}") + raise ValueError(f"Error creating LLM agent {agent.name}: {str(e)}") async def _get_sub_agents( self, sub_agent_ids: List[str] - ) -> List[Tuple[LlmAgent, Optional[AsyncExitStack]]]: - """Get and create LLM sub-agents.""" + ) -> List[Tuple[BaseAgent, Optional[AsyncExitStack]]]: + """Get and create sub-agents with proper type validation.""" sub_agents = [] for sub_agent_id in sub_agent_ids: sub_agent_id_str = str(sub_agent_id) @@ -197,39 +227,50 @@ async def _get_sub_agents( logger.info(f"Sub-agent found: {agent.name} (type: {agent.type})") - if agent.type == "llm": - sub_agent, exit_stack = await self._create_llm_agent(agent) - elif agent.type == "a2a": - sub_agent, exit_stack = await self.build_a2a_agent(agent) - elif agent.type == "workflow": - sub_agent, exit_stack = await self.build_workflow_agent(agent) - elif agent.type == "task": - sub_agent, exit_stack = await self.build_task_agent(agent) - elif agent.type == "sequential": - sub_agent, exit_stack = await self.build_composite_agent(agent) - elif agent.type == "parallel": - sub_agent, exit_stack = await self.build_composite_agent(agent) - elif agent.type == "loop": - sub_agent, exit_stack = await self.build_composite_agent(agent) - else: - raise ValueError(f"Invalid agent type: {agent.type}") - - sub_agents.append(sub_agent) - logger.info(f"Sub-agent added: {agent.name}") + try: + if agent.type == "llm": + # Verificar se tem modelo antes de criar + if not agent.model or (isinstance(agent.model, str) and agent.model.strip() == ""): + logger.error(f"LLM sub-agent {agent.name} does not have a model configured") + raise ValueError(f"LLM sub-agent {agent.name} requires a model configuration") + sub_agent, exit_stack = await self._create_llm_agent(agent) + elif agent.type == "a2a": + sub_agent, exit_stack = await self.build_a2a_agent(agent) + elif agent.type == "workflow": + # Workflow agents não precisam de modelo + sub_agent, exit_stack = await self.build_workflow_agent(agent) + elif agent.type == "task": + sub_agent, exit_stack = await self.build_task_agent(agent) + elif agent.type == "sequential": + sub_agent, exit_stack = await self.build_composite_agent(agent) + elif agent.type == "parallel": + sub_agent, exit_stack = await self.build_composite_agent(agent) + elif agent.type == "loop": + sub_agent, exit_stack = await self.build_composite_agent(agent) + else: + raise ValueError(f"Invalid agent type: {agent.type}") + + sub_agents.append((sub_agent, exit_stack)) + logger.info(f"Sub-agent added: {agent.name}") + + except Exception as e: + logger.error(f"Error creating sub-agent {agent.name}: {str(e)}") + raise ValueError(f"Error creating sub-agent {agent.name}: {str(e)}") logger.info(f"Sub-agents created: {len(sub_agents)}") - logger.info(f"Sub-agents: {str(sub_agents)}") - return sub_agents async def build_llm_agent( self, root_agent, enabled_tools: List[str] = [] ) -> Tuple[LlmAgent, Optional[AsyncExitStack]]: """Build an LLM agent with its sub-agents.""" - logger.info("Creating LLM agent") + logger.info(f"Creating LLM agent: {root_agent.name}") + + if root_agent.type != "llm": + raise ValueError(f"Expected LLM agent, got {root_agent.type}") sub_agents = [] - if root_agent.config.get("sub_agents"): + if root_agent.config and root_agent.config.get("sub_agents"): sub_agents_with_stacks = await self._get_sub_agents( root_agent.config.get("sub_agents") ) @@ -241,20 +282,21 @@ async def build_llm_agent( if sub_agents: root_llm_agent.sub_agents = sub_agents + logger.info(f"LLM agent built successfully: {root_agent.name}") return root_llm_agent, exit_stack async def build_a2a_agent( self, root_agent - ) -> Tuple[BaseAgent, Optional[AsyncExitStack]]: + ) -> Tuple[A2ACustomAgent, Optional[AsyncExitStack]]: """Build an A2A agent with its sub-agents.""" - logger.info(f"Creating A2A agent from {root_agent.agent_card_url}") + logger.info(f"Creating A2A agent from {root_agent.name}") if not root_agent.agent_card_url: raise ValueError("agent_card_url is required for a2a agents") try: sub_agents = [] - if root_agent.config.get("sub_agents"): + if root_agent.config and root_agent.config.get("sub_agents"): sub_agents_with_stacks = await self._get_sub_agents( root_agent.config.get("sub_agents") ) @@ -288,6 +330,9 @@ async def build_workflow_agent( """Build a workflow agent with its sub-agents.""" logger.info(f"Creating Workflow agent from {root_agent.name}") + if root_agent.type != "workflow": + raise ValueError(f"Expected workflow agent, got {root_agent.type}") + agent_config = root_agent.config or {} if not agent_config.get("workflow"): @@ -295,7 +340,7 @@ async def build_workflow_agent( try: sub_agents = [] - if root_agent.config.get("sub_agents"): + if root_agent.config and root_agent.config.get("sub_agents"): sub_agents_with_stacks = await self._get_sub_agents( root_agent.config.get("sub_agents") ) @@ -304,15 +349,20 @@ async def build_workflow_agent( config = root_agent.config or {} timeout = config.get("timeout", 300) - workflow_agent = WorkflowAgent( - name=root_agent.name, - flow_json=agent_config.get("workflow"), - timeout=timeout, - description=root_agent.description - or f"Workflow Agent for {root_agent.name}", - sub_agents=sub_agents, - db=self.db, - ) + kwargs = { + "name": root_agent.name, + "flow_json": agent_config.get("workflow"), + "timeout": timeout, + "description": root_agent.description or f"Workflow Agent for {root_agent.name}", + "sub_agents": sub_agents, + "db": self.db, + } + + # Se o root_agent tiver modelo, não passá-lo para o WorkflowAgent + if hasattr(root_agent, 'model') and root_agent.model: + logger.warning(f"Workflow agent {root_agent.name} has model '{root_agent.model}' configured, but workflow agents should not have models. Ignoring model.") + + workflow_agent = WorkflowAgent(**kwargs) logger.info(f"Workflow agent created successfully: {root_agent.name}") @@ -328,6 +378,9 @@ async def build_task_agent( """Build a task agent with its sub-agents.""" logger.info(f"Creating Task agent: {root_agent.name}") + if root_agent.type != "task": + raise ValueError(f"Expected task agent, got {root_agent.type}") + agent_config = root_agent.config or {} if not agent_config.get("tasks"): @@ -336,7 +389,7 @@ async def build_task_agent( try: # Get sub-agents if there are any sub_agents = [] - if root_agent.config.get("sub_agents"): + if root_agent.config and root_agent.config.get("sub_agents"): sub_agents_with_stacks = await self._get_sub_agents( root_agent.config.get("sub_agents") ) @@ -380,7 +433,11 @@ async def build_composite_agent( f"Processing sub-agents for agent {root_agent.type} (ID: {root_agent.id}, Name: {root_agent.name})" ) - if not root_agent.config.get("sub_agents"): + valid_composite_types = ["sequential", "parallel", "loop"] + if root_agent.type not in valid_composite_types: + raise ValueError(f"Expected composite agent type ({valid_composite_types}), got {root_agent.type}") + + if not root_agent.config or not root_agent.config.get("sub_agents"): logger.error( f"Sub_agents configuration not found or empty for agent {root_agent.name}" ) @@ -401,39 +458,51 @@ async def build_composite_agent( sub_agents = [agent for agent, _ in sub_agents_with_stacks] logger.info(f"Extracted sub-agents: {[agent.name for agent in sub_agents]}") - if root_agent.type == "sequential": - logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents") - return ( - SequentialAgent( - name=root_agent.name, - sub_agents=sub_agents, - description=root_agent.config.get("description", ""), - ), - None, - ) - elif root_agent.type == "parallel": - logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents") - return ( - ParallelAgent( - name=root_agent.name, - sub_agents=sub_agents, - description=root_agent.config.get("description", ""), - ), - None, - ) - elif root_agent.type == "loop": - logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents") - return ( - LoopAgent( - name=root_agent.name, - sub_agents=sub_agents, - description=root_agent.config.get("description", ""), - max_iterations=root_agent.config.get("max_iterations", 5), - ), - None, - ) - else: - raise ValueError(f"Invalid agent type: {root_agent.type}") + if not sub_agents: + raise ValueError(f"No valid sub-agents found for {root_agent.type} agent {root_agent.name}") + + try: + if root_agent.type == "sequential": + logger.info(f"Creating SequentialAgent with {len(sub_agents)} sub-agents") + return ( + SequentialAgent( + name=root_agent.name, + sub_agents=sub_agents, + description=root_agent.description or root_agent.config.get("description", ""), + ), + None, + ) + elif root_agent.type == "parallel": + logger.info(f"Creating ParallelAgent with {len(sub_agents)} sub-agents") + return ( + ParallelAgent( + name=root_agent.name, + sub_agents=sub_agents, + description=root_agent.description or root_agent.config.get("description", ""), + ), + None, + ) + elif root_agent.type == "loop": + logger.info(f"Creating LoopAgent with {len(sub_agents)} sub-agents") + max_iterations = root_agent.config.get("max_iterations", 5) + if max_iterations <= 0: + logger.warning(f"Invalid max_iterations ({max_iterations}) for LoopAgent, using default 5") + max_iterations = 5 + return ( + LoopAgent( + name=root_agent.name, + sub_agents=sub_agents, + description=root_agent.description or root_agent.config.get("description", ""), + max_iterations=max_iterations, + ), + None, + ) + else: + raise ValueError(f"Invalid composite agent type: {root_agent.type}") + + except Exception as e: + logger.error(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}") + raise ValueError(f"Error creating {root_agent.type} agent {root_agent.name}: {str(e)}") async def build_agent(self, root_agent, enabled_tools: List[str] = []) -> Tuple[ LlmAgent @@ -446,13 +515,29 @@ async def build_agent(self, root_agent, enabled_tools: List[str] = []) -> Tuple[ Optional[AsyncExitStack], ]: """Build the appropriate agent based on the type of the root agent.""" - if root_agent.type == "llm": - return await self.build_llm_agent(root_agent, enabled_tools) - elif root_agent.type == "a2a": - return await self.build_a2a_agent(root_agent) - elif root_agent.type == "workflow": - return await self.build_workflow_agent(root_agent) - elif root_agent.type == "task": - return await self.build_task_agent(root_agent) - else: - return await self.build_composite_agent(root_agent) + + if not root_agent: + raise ValueError("root_agent cannot be None") + + if not hasattr(root_agent, 'type') or not root_agent.type: + raise ValueError("root_agent must have a valid type") + + logger.info(f"Building agent: {root_agent.name} (type: {root_agent.type})") + + try: + if root_agent.type == "llm": + return await self.build_llm_agent(root_agent, enabled_tools) + elif root_agent.type == "a2a": + return await self.build_a2a_agent(root_agent) + elif root_agent.type == "workflow": + return await self.build_workflow_agent(root_agent) + elif root_agent.type == "task": + return await self.build_task_agent(root_agent) + elif root_agent.type in ["sequential", "parallel", "loop"]: + return await self.build_composite_agent(root_agent) + else: + raise ValueError(f"Unknown agent type: {root_agent.type}") + + except Exception as e: + logger.error(f"Error building agent {root_agent.name}: {str(e)}") + raise \ No newline at end of file diff --git a/src/services/adk/agent_runner.py b/src/services/adk/agent_runner.py index ef727f5f..6daae26e 100644 --- a/src/services/adk/agent_runner.py +++ b/src/services/adk/agent_runner.py @@ -458,7 +458,7 @@ async def run_agent_stream( async for event in events_async: try: - event_dict = event.dict() + event_dict = event.model_dump() event_dict = convert_sets(event_dict) if "content" in event_dict and event_dict["content"]: diff --git a/src/services/adk/custom_agents/workflow_agent.py b/src/services/adk/custom_agents/workflow_agent.py index 18376347..1a2ebf6e 100644 --- a/src/services/adk/custom_agents/workflow_agent.py +++ b/src/services/adk/custom_agents/workflow_agent.py @@ -30,6 +30,7 @@ └──────────────────────────────────────────────────────────────────────────────┘ """ + from datetime import datetime from google.adk.agents import BaseAgent from google.adk.agents.invocation_context import InvocationContext @@ -40,11 +41,14 @@ import uuid from src.services.agent_service import get_agent +from src.utils.logger import setup_logger from sqlalchemy.orm import Session from langgraph.graph import StateGraph, END +logger = setup_logger(__name__) + class State(TypedDict): content: List[Event] @@ -63,6 +67,9 @@ class WorkflowAgent(BaseAgent): This agent allows defining and executing complex workflows between multiple agents using LangGraph for orchestration. + + IMPORTANT: Workflow agents are orchestrators and should NOT have a model configured. + They delegate to sub-agents that have their own models. """ # Field declarations for Pydantic @@ -89,6 +96,21 @@ def __init__( sub_agents: List of sub-agents to be executed after the workflow agent db: Session """ + + # Workflow agents não devem ter modelos + if 'model' in kwargs: + logger.warning(f"Removing model from workflow agent {name}. Workflow agents should not have models.") + del kwargs['model'] + + if not flow_json: + raise ValueError(f"Workflow agent {name} requires flow_json configuration") + + if not isinstance(flow_json, dict): + raise ValueError(f"Workflow agent {name} flow_json must be a dictionary") + + if not flow_json.get('nodes'): + raise ValueError(f"Workflow agent {name} flow_json must contain nodes") + # Initialize base class super().__init__( name=name, @@ -98,9 +120,13 @@ def __init__( db=db, **kwargs, ) + + if hasattr(self, 'model'): + logger.warning(f"Workflow agent {name} had a model attribute. Removing it.") + delattr(self, 'model') - print( - f"Workflow agent initialized with {len(flow_json.get('nodes', []))} nodes" + logger.info( + f"Workflow agent '{name}' initialized with {len(flow_json.get('nodes', []))} nodes" ) async def _create_node_functions(self, ctx: InvocationContext): @@ -112,11 +138,12 @@ async def start_node_function( node_id: str, node_data: Dict[str, Any], ) -> AsyncGenerator[State, None]: - print("\n🏁 INITIAL NODE") + logger.info(f"🏁 INITIAL NODE: {node_id}") content = state.get("content", []) if not content: + logger.warning("No content found in initial state") content = [ Event( author=f"workflow-node:{node_id}", @@ -128,9 +155,11 @@ async def start_node_function( "status": "error", "node_outputs": {}, "cycle_count": 0, - "conversation_history": ctx.session.events, + "conversation_history": ctx.session.events if ctx.session else [], + "session_id": state.get("session_id", ""), } return + session_id = state.get("session_id", "") # Store specific results for this node @@ -149,7 +178,7 @@ async def start_node_function( "node_outputs": node_outputs, "cycle_count": 0, "session_id": session_id, - "conversation_history": ctx.session.events, + "conversation_history": ctx.session.events if ctx.session else [], } # Generic function for agent nodes @@ -163,7 +192,7 @@ async def agent_node_function( # Increment cycle counter cycle_count = state.get("cycle_count", 0) + 1 - print(f"\n👤 AGENT: {agent_name} (Cycle {cycle_count})") + logger.info(f"👤 AGENT: {agent_name} (Cycle {cycle_count})") content = state.get("content", []) session_id = state.get("session_id", "") @@ -171,14 +200,32 @@ async def agent_node_function( # Get conversation history conversation_history = state.get("conversation_history", []) + if not agent_id: + logger.error(f"Agent node {node_id} does not have a valid agent_id") + yield { + "content": [ + Event( + author=f"workflow-node:{node_id}", + content=Content(parts=[Part(text="Agent ID not configured")]), + ) + ], + "session_id": session_id, + "status": "error", + "node_outputs": {}, + "cycle_count": cycle_count, + "conversation_history": conversation_history, + } + return + agent = get_agent(self.db, agent_id) if not agent: + logger.error(f"Agent not found for ID: {agent_id}") yield { "content": [ Event( author=f"workflow-node:{node_id}", - content=Content(parts=[Part(text="Agent not found")]), + content=Content(parts=[Part(text=f"Agent not found: {agent_id}")]), ) ], "session_id": session_id, @@ -189,44 +236,65 @@ async def agent_node_function( } return - # Import moved to inside the function to avoid circular import - from src.services.adk.agent_builder import AgentBuilder + try: + # Import moved to inside the function to avoid circular import + from src.services.adk.agent_builder import AgentBuilder - agent_builder = AgentBuilder(self.db) - root_agent, exit_stack = await agent_builder.build_agent(agent) + agent_builder = AgentBuilder(self.db) + root_agent, exit_stack = await agent_builder.build_agent(agent) - new_content = [] - async for event in root_agent.run_async(ctx): - conversation_history.append(event) - - modified_event = Event( - author=f"workflow-node:{node_id}", content=event.content - ) - new_content.append(modified_event) + new_content = [] + async for event in root_agent.run_async(ctx): + conversation_history.append(event) + + modified_event = Event( + author=f"workflow-node:{node_id}", content=event.content + ) + new_content.append(modified_event) + logger.debug(f"Agent {agent_name} generated {len(new_content)} events") - print(f"New content: {new_content}") + node_outputs = state.get("node_outputs", {}) + node_outputs[node_id] = { + "processed_by": agent_name, + "agent_id": agent_id, + "agent_content": new_content, + "cycle": cycle_count, + "processed_at": datetime.now().isoformat(), + } - node_outputs = state.get("node_outputs", {}) - node_outputs[node_id] = { - "processed_by": agent_name, - "agent_content": new_content, - "cycle": cycle_count, - } + content = content + new_content - content = content + new_content + yield { + "content": content, + "status": "processed_by_agent", + "node_outputs": node_outputs, + "cycle_count": cycle_count, + "conversation_history": conversation_history, + "session_id": session_id, + } - yield { - "content": content, - "status": "processed_by_agent", - "node_outputs": node_outputs, - "cycle_count": cycle_count, - "conversation_history": conversation_history, - "session_id": session_id, - } + if exit_stack: + try: + await exit_stack.aclose() + except Exception as e: + logger.warning(f"Error closing exit stack for agent {agent_name}: {e}") - if exit_stack: - await exit_stack.aclose() + except Exception as e: + logger.error(f"Error executing agent {agent_name}: {str(e)}") + yield { + "content": [ + Event( + author=f"workflow-node:{node_id}", + content=Content(parts=[Part(text=f"Error executing agent: {str(e)}")]), + ) + ], + "session_id": session_id, + "status": "agent_error", + "node_outputs": state.get("node_outputs", {}), + "cycle_count": cycle_count, + "conversation_history": conversation_history, + } # Function for condition nodes async def condition_node_function( @@ -236,7 +304,7 @@ async def condition_node_function( conditions = node_data.get("conditions", []) cycle_count = state.get("cycle_count", 0) - print(f"\n🔄 CONDITION: {label} (Cycle {cycle_count})") + logger.info(f"🔄 CONDITION: {label} (Cycle {cycle_count})") content = state.get("content", []) conversation_history = state.get("conversation_history", []) @@ -245,16 +313,17 @@ async def condition_node_function( if content and len(content) > 0: for event in reversed(content): if ( - event.author != "agent" - or not hasattr(event.content, "parts") - or not event.content.parts + hasattr(event, 'author') and + event.author != "user" and + hasattr(event, 'content') and + hasattr(event.content, "parts") and + event.content.parts ): latest_event = event break + if latest_event: - print( - f"Evaluating condition only for the most recent event: '{latest_event}'" - ) + logger.debug(f"Evaluating condition for latest event from: {latest_event.author}") # Use only the most recent event for condition evaluation evaluation_state = state.copy() @@ -273,25 +342,24 @@ async def condition_node_function( operator = condition_data.get("operator") expected_value = condition_data.get("value") - print( - f" Checking if {field} {operator} '{expected_value}' (current value: '{evaluation_state.get(field, '')}')" + logger.debug( + f"Checking condition: {field} {operator} '{expected_value}'" ) + if self._evaluate_condition(condition, evaluation_state): conditions_met.append(condition_id) condition_details.append( f"{field} {operator} '{expected_value}' ✅" ) - print(f" ✅ Condition {condition_id} met!") + logger.info(f"✅ Condition {condition_id} met!") else: condition_details.append( f"{field} {operator} '{expected_value}' ❌" ) - # Check if the cycle reached the limit (extra security) - if cycle_count >= 10: - print( - f"⚠️ ATTENTION: Cycle limit reached ({cycle_count}). Forcing termination." - ) + max_cycles = 10 # Poderia vir da configuração + if cycle_count >= max_cycles: + logger.warning(f"Cycle limit reached ({cycle_count}). Forcing termination.") condition_content = [ Event( @@ -314,10 +382,10 @@ async def condition_node_function( node_outputs = state.get("node_outputs", {}) node_outputs[node_id] = { "condition_evaluated": label, - "content_evaluated": content, "conditions_met": conditions_met, "condition_details": condition_details, "cycle": cycle_count, + "evaluated_at": datetime.now().isoformat(), } # Prepare a more descriptive message about the conditions @@ -334,7 +402,8 @@ async def condition_node_function( ) ] ), - ) ] + ) + ] content = content + condition_content yield { @@ -353,7 +422,7 @@ async def message_node_function( message_type = message_data.get("type", "text") message_content = message_data.get("content", "") - print(f"\n💬 MESSAGE-NODE: {message_content}") + logger.info(f"💬 MESSAGE-NODE: {message_content}") content = state.get("content", []) session_id = state.get("session_id", "") @@ -371,6 +440,8 @@ async def message_node_function( node_outputs[node_id] = { "message_type": message_type, "message_content": message_content, + "label": label, + "processed_at": datetime.now().isoformat(), } yield { @@ -378,7 +449,8 @@ async def message_node_function( "status": "message_added", "node_outputs": node_outputs, "cycle_count": state.get("cycle_count", 0), - "conversation_history": conversation_history, "session_id": session_id, + "conversation_history": conversation_history, + "session_id": session_id, } async def delay_node_function( @@ -389,6 +461,10 @@ async def delay_node_function( delay_unit = delay_data.get("unit", "seconds") delay_description = delay_data.get("description", "") + if delay_value <= 0: + logger.warning(f"Invalid delay value: {delay_value}. Using 1 second.") + delay_value = 1 + # Convert to seconds based on unit delay_seconds = delay_value if delay_unit == "minutes": @@ -397,7 +473,7 @@ async def delay_node_function( delay_seconds = delay_value * 3600 label = node_data.get("label", "delay_node") - print(f"\n⏱️ DELAY-NODE: {delay_value} {delay_unit} - {delay_description}") + logger.info(f"⏱️ DELAY-NODE: {delay_value} {delay_unit} ({delay_seconds}s) - {delay_description}") content = state.get("content", []) session_id = state.get("session_id", "") @@ -409,13 +485,17 @@ async def delay_node_function( "delay_value": delay_value, "delay_unit": delay_unit, "delay_seconds": delay_seconds, + "delay_description": delay_description, "delay_start_time": datetime.now().isoformat(), } # Actually perform the delay import asyncio - await asyncio.sleep(delay_seconds) - + try: + await asyncio.sleep(delay_seconds) + except asyncio.CancelledError: + logger.warning(f"Delay in node {node_id} was cancelled") + # Continue execution even if delay was cancelled # Update node outputs with completion information node_outputs[node_id]["delay_end_time"] = datetime.now().isoformat() @@ -424,7 +504,8 @@ async def delay_node_function( yield { "content": content, "status": "delay_completed", - "node_outputs": node_outputs, "cycle_count": state.get("cycle_count", 0), + "node_outputs": node_outputs, + "cycle_count": state.get("cycle_count", 0), "conversation_history": conversation_history, "session_id": session_id, } @@ -452,7 +533,7 @@ def _evaluate_condition(self, condition: Dict[str, Any], state: State) -> bool: result = self._process_condition(operator, actual_value, expected_value) - print(f" Check '{operator}': {result}") + logger.debug(f"Condition check '{operator}': {result}") return result return False @@ -488,7 +569,7 @@ def _extract_text_from_events(self, events): if extracted_texts: joined_text = " ".join(extracted_texts) - print(f" Extracted text from events: '{joined_text[:100]}...'") + logger.debug(f"Extracted text from events: '{joined_text[:100]}...'") return joined_text return "" @@ -524,6 +605,7 @@ def _process_operator(self, operator, actual_value, actual_str, expected_str): elif operator in ["matches", "not_matches"]: return self._check_regex(operator, actual_str, expected_str) + logger.warning(f"Unknown operator: {operator}") return False def _check_definition(self, operator, actual_value): @@ -563,8 +645,8 @@ def _check_numeric(self, operator, actual_str, expected_str): else: # less_than_or_equal return actual_num <= expected_num except (ValueError, TypeError): - print( - f" Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'" + logger.warning( + f"Error converting values for numeric comparison: '{actual_str[:100]}...' and '{expected_str}'" ) return False @@ -579,7 +661,7 @@ def _check_regex(self, operator, actual_str, expected_str): else: # not_matches return not bool(pattern.search(actual_str)) except re.error: - print(f" Error in regular expression: '{expected_str}'") + logger.warning(f"Error in regular expression: '{expected_str}'") return ( operator == "not_matches" ) # Return True for not_matches, False for matches @@ -589,8 +671,8 @@ def _case_insensitive_comparison(self, expected_str, actual_str, operator): expected_lower = expected_str.lower() actual_lower = actual_str.lower() - print( - f" Comparison '{operator}' without case distinction: '{expected_lower}' in '{actual_lower[:100]}...'" + logger.debug( + f"Comparison '{operator}' case insensitive: '{expected_lower}' in '{actual_lower[:100]}...'" ) if operator == "contains": @@ -627,14 +709,13 @@ def _create_flow_router(self, flow_data: Dict[str, Any]): # Routing function for each specific node def create_router_for_node(node_id: str): def router(state: State) -> str: - print(f"Routing from node: {node_id}") + logger.debug(f"Routing from node: {node_id}") # Check if the cycle limit has been reached cycle_count = state.get("cycle_count", 0) - if cycle_count >= 10: - print( - f"⚠️ Cycle limit ({cycle_count}) reached. Finalizing the flow." - ) + max_cycles = 10 # Configurável + if cycle_count >= max_cycles: + logger.warning(f"Cycle limit ({cycle_count}) reached. Finalizing the flow.") return END # If it's a condition node, evaluate the conditions @@ -648,32 +729,29 @@ def router(state: State) -> str: if conditions_met: any_condition_met = True condition_id = conditions_met[0] - print( - f"Using stored condition evaluation result: Condition {condition_id} met." - ) + logger.debug(f"Using stored condition result: Condition {condition_id} met.") if ( node_id in edges_map and condition_id in edges_map[node_id] ): return edges_map[node_id][condition_id] else: - print( - "Using stored condition evaluation result: No conditions met." - ) + logger.debug("Using stored condition result: No conditions met.") else: + # Evaluate conditions for condition in conditions: condition_id = condition.get("id") # Get latest event for evaluation, ignoring condition node informational events content = state.get("content", []) - # Filter out events generated by condition nodes or informational messages + # Filter out events generated by condition nodes or that contain evaluation results filtered_content = [] for event in content: # Ignore events from condition nodes or that contain evaluation results if not hasattr(event, "author") or not ( - event.author.startswith("Condition") - or "Condition evaluated:" in str(event) + event.author.startswith("workflow-node:") and + "Condition evaluated:" in str(event) ): filtered_content.append(event) @@ -687,9 +765,7 @@ def router(state: State) -> str: if is_condition_met: any_condition_met = True - print( - f"Condition {condition_id} met. Moving to the next node." - ) + logger.debug(f"Condition {condition_id} met. Moving to next node.") # Find the connection that uses this condition_id as a handle if ( @@ -698,9 +774,7 @@ def router(state: State) -> str: ): return edges_map[node_id][condition_id] else: - print( - f"Condition {condition_id} not met. Continuing evaluation or using default path." - ) + logger.debug(f"Condition {condition_id} not met.") # If no condition is met, use the bottom-handle if available if not any_condition_met: @@ -708,14 +782,10 @@ def router(state: State) -> str: node_id in edges_map and "bottom-handle" in edges_map[node_id] ): - print( - "No condition met. Using default path (bottom-handle)." - ) + logger.debug("No condition met. Using default path (bottom-handle).") return edges_map[node_id]["bottom-handle"] else: - print( - "No condition met and no default path. Closing the flow." - ) + logger.debug("No condition met and no default path. Closing the flow.") return END # For regular nodes, simply follow the first available connection @@ -731,7 +801,7 @@ def router(state: State) -> str: return edges_map[node_id][first_handle] # If there is no output connection, close the flow - print(f"No output connection from node {node_id}. Closing the flow.") + logger.debug(f"No output connection from node {node_id}. Closing the flow.") return END return router @@ -745,6 +815,9 @@ async def _create_graph( # Extract nodes from the flow nodes = flow_data.get("nodes", []) + if not nodes: + raise ValueError("Flow data must contain at least one node") + # Initialize StateGraph graph_builder = StateGraph(State) @@ -754,34 +827,60 @@ async def _create_graph( # Dictionary to store specific functions for each node node_specific_functions = {} + valid_node_types = set(node_functions.keys()) + # Add nodes to the graph for node in nodes: node_id = node.get("id") node_type = node.get("type") node_data = node.get("data", {}) - if node_type in node_functions: - # Create a specific function for this node - def create_node_function(node_type, node_id, node_data): - async def node_function(state): - # Consume the asynchronous generator and return the last result - result = None + if not node_id: + logger.warning(f"Skipping node without ID: {node}") + continue + + if node_type not in valid_node_types: + logger.warning(f"Unknown node type '{node_type}' for node {node_id}. Skipping.") + continue + + # Create a specific function for this node + def create_node_function(node_type, node_id, node_data): + async def node_function(state): + # Consume the asynchronous generator and return the last result + result = None + try: async for item in node_functions[node_type]( state, node_id, node_data ): result = item return result + except Exception as e: + logger.error(f"Error in node {node_id} ({node_type}): {str(e)}") + # Return error state + return { + "content": [ + Event( + author=f"workflow-node:{node_id}", + content=Content(parts=[Part(text=f"Node error: {str(e)}")]), + ) + ], + "status": "node_error", + "node_outputs": state.get("node_outputs", {}), + "cycle_count": state.get("cycle_count", 0), + "conversation_history": state.get("conversation_history", []), + "session_id": state.get("session_id", ""), + } + + return node_function + + # Add specific function to the dictionary + node_specific_functions[node_id] = create_node_function( + node_type, node_id, node_data + ) - return node_function - - # Add specific function to the dictionary - node_specific_functions[node_id] = create_node_function( - node_type, node_id, node_data - ) - - # Add node to the graph - print(f"Adding node {node_id} of type {node_type}") - graph_builder.add_node(node_id, node_specific_functions[node_id]) + # Add node to the graph + logger.debug(f"Adding node {node_id} of type {node_type}") + graph_builder.add_node(node_id, node_specific_functions[node_id]) # Create function to generate specific routers create_router = self._create_flow_router(flow_data) @@ -808,8 +907,8 @@ async def node_function(state): node_router = create_router(node_id) # Add conditional connections - print(f"Adding conditional connections for node {node_id}") - print(f"Possible destinations: {edge_destinations}") + logger.debug(f"Adding conditional connections for node {node_id}") + logger.debug(f"Possible destinations: {list(edge_destinations.keys())}") graph_builder.add_conditional_edges( node_id, node_router, edge_destinations @@ -825,35 +924,56 @@ async def node_function(state): # If there is no start-node, use the first node found if not entry_point and nodes: entry_point = nodes[0].get("id") + logger.warning(f"No start-node found, using first node as entry point: {entry_point}") # Define the entry point if entry_point: - print(f"Defining entry point: {entry_point}") + logger.info(f"Setting entry point: {entry_point}") graph_builder.set_entry_point(entry_point) + else: + raise ValueError("No valid entry point found for workflow") # Compile the graph - return graph_builder.compile() + try: + compiled_graph = graph_builder.compile() + logger.info("Workflow graph compiled successfully") + return compiled_graph + except Exception as e: + logger.error(f"Error compiling workflow graph: {str(e)}") + raise ValueError(f"Error compiling workflow graph: {str(e)}") async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: """Implementation of the workflow agent executing the defined workflow and returning results.""" + + if hasattr(self, 'model') and self.model: + logger.error(f"Workflow agent {self.name} should not have a model configured") + raise ValueError(f"Workflow agent {self.name} is an orchestrator and should not have a model. Models should be configured on sub-agents.") + try: + logger.info(f"Starting workflow execution for agent: {self.name}") + logger.debug(f"Context session ID: {ctx.session.id if ctx.session else 'No session'}") + user_message = await self._extract_user_message(ctx) session_id = self._get_session_id(ctx) + + if not self.flow_json: + raise ValueError("Workflow agent has no flow_json configured") + graph = await self._create_graph(ctx, self.flow_json) initial_state = await self._prepare_initial_state( ctx, user_message, session_id ) - print("\n🚀 Starting workflow execution:") - print(f"Initial content: {user_message[:100]}...") + logger.info(f"🚀 Starting workflow execution with initial message: {user_message[:100]}...") # Iterar sobre o AsyncGenerator em vez de usar await async for event in self._execute_workflow(ctx, graph, initial_state): yield event except Exception as e: + logger.error(f"Error in workflow execution: {str(e)}", exc_info=True) yield await self._handle_workflow_error(e) async def _extract_user_message(self, ctx: InvocationContext) -> str: @@ -861,24 +981,36 @@ async def _extract_user_message(self, ctx: InvocationContext) -> str: # Try to find message in session events if ctx.session and hasattr(ctx.session, "events") and ctx.session.events: for event in reversed(ctx.session.events): - if event.author == "user" and event.content and event.content.parts: - print("Message found in session events") + if ( + hasattr(event, 'author') and + event.author == "user" and + hasattr(event, 'content') and + event.content and + hasattr(event.content, 'parts') and + event.content.parts + ): + logger.debug("User message found in session events") return event.content.parts[0].text # Try to find message in session state - if ctx.session and ctx.session.state: + if ctx.session and hasattr(ctx.session, 'state') and ctx.session.state: if "user_message" in ctx.session.state: return ctx.session.state["user_message"] elif "message" in ctx.session.state: return ctx.session.state["message"] - return "" + logger.warning("No user message found in context") + return "No user message provided" def _get_session_id(self, ctx: InvocationContext) -> str: """Gets or generates a session ID.""" - if ctx.session and hasattr(ctx.session, "id"): + if ctx.session and hasattr(ctx.session, "id") and ctx.session.id: return str(ctx.session.id) - return str(uuid.uuid4()) + + # Generate a new session ID + new_session_id = str(uuid.uuid4()) + logger.debug(f"Generated new session ID: {new_session_id}") + return new_session_id async def _prepare_initial_state( self, ctx: InvocationContext, user_message: str, session_id: str @@ -889,9 +1021,13 @@ async def _prepare_initial_state( content=Content(parts=[Part(text=user_message)]), ) - conversation_history = ctx.session.events or [user_event] + conversation_history = [] + if ctx.session and hasattr(ctx.session, 'events') and ctx.session.events: + conversation_history = ctx.session.events.copy() + else: + conversation_history = [user_event] - return State( + initial_state = State( content=[user_event], status="started", session_id=session_id, @@ -899,34 +1035,61 @@ async def _prepare_initial_state( node_outputs={}, conversation_history=conversation_history, ) + + logger.debug(f"Initial state prepared with {len(conversation_history)} history events") + return initial_state async def _execute_workflow( self, ctx: InvocationContext, graph: StateGraph, initial_state: State ) -> AsyncGenerator[Event, None]: """Executes the workflow graph and yields events.""" sent_events = 0 + total_iterations = 0 + max_iterations = 100 - async for state in graph.astream(initial_state, {"recursion_limit": 100}): - for node_state in state.values(): - content = node_state.get("content", []) - for event in content[sent_events:]: - if event.author != "user": - yield event - sent_events = len(content) + try: + async for state in graph.astream(initial_state, {"recursion_limit": max_iterations}): + total_iterations += 1 + + if total_iterations > max_iterations: + logger.warning(f"Maximum iterations ({max_iterations}) reached, stopping workflow") + break + + for node_state in state.values(): + content = node_state.get("content", []) + + # Yield new events that haven't been sent yet + for event in content[sent_events:]: + if hasattr(event, 'author') and event.author != "user": + yield event + + sent_events = len(content) + + logger.info(f"Workflow completed after {total_iterations} iterations") - # Execute sub-agents if any - for sub_agent in self.sub_agents: - async for event in sub_agent.run_async(ctx): - yield event + except Exception as e: + logger.error(f"Error during workflow execution: {str(e)}") + yield await self._handle_workflow_error(e) + + if self.sub_agents: + logger.info(f"Executing {len(self.sub_agents)} sub-agents") + for sub_agent in self.sub_agents: + try: + async for event in sub_agent.run_async(ctx): + yield event + except Exception as e: + logger.error(f"Error executing sub-agent {sub_agent.name}: {str(e)}") + yield await self._handle_workflow_error(e) async def _handle_workflow_error(self, error: Exception) -> Event: """Creates an error event for workflow execution errors.""" - error_msg = f"Error executing the workflow agent: {str(error)}" - print(error_msg) + error_msg = f"Error executing workflow agent '{self.name}': {str(error)}" + logger.error(error_msg) + return Event( author=f"workflow-error:{self.name}", content=Content( role="agent", parts=[Part(text=error_msg)], ), - ) + ) \ No newline at end of file