diff --git a/document_processor.py b/document_processor.py
index e533d8a3..7e4cf7e3 100644
--- a/document_processor.py
+++ b/document_processor.py
@@ -24,7 +24,7 @@
     TokenFilter,
     VoteFilter,
 )
-from services.workflows.vector_react import add_documents_to_vectors
+from services.workflows.chat import add_documents_to_vectors

 # Load environment variables
 dotenv.load_dotenv()
diff --git a/proposal_evaluation_test.py b/proposal_evaluation_test.py
index 449b063f..517014eb 100644
--- a/proposal_evaluation_test.py
+++ b/proposal_evaluation_test.py
@@ -97,6 +97,9 @@ async def test_proposal_evaluation_workflow():
     # Create a test proposal
     proposal_id = await create_test_proposal(dao_id)

+    # Use a consistent test wallet ID
+    test_wallet_id = UUID("532fd36b-8a9d-4fdd-82d2-25ddcf007488")
+
     # Test scenarios
     scenarios = [
         {
@@ -107,7 +110,7 @@ async def test_proposal_evaluation_workflow():
         },
         {
             "name": "Auto-vote Enabled",
-            "auto_vote": False,  # Fixed: Changed to True for auto-vote scenario
+            "auto_vote": True,  # Corrected: Changed to True for auto-vote scenario
             "confidence_threshold": 0.7,
             "description": "Testing proposal evaluation with auto-voting",
         },
@@ -128,6 +131,7 @@ async def test_proposal_evaluation_workflow():
         if scenario["auto_vote"]:
             result = await evaluate_and_vote_on_proposal(
                 proposal_id=proposal_id,
+                wallet_id=test_wallet_id,  # Add wallet_id for auto-vote scenarios
                 auto_vote=scenario["auto_vote"],
                 confidence_threshold=scenario["confidence_threshold"],
                 dao_id=dao_id,
@@ -135,7 +139,7 @@ async def test_proposal_evaluation_workflow():
         else:
             result = await evaluate_proposal_only(
                 proposal_id=proposal_id,
-                wallet_id=UUID("532fd36b-8a9d-4fdd-82d2-25ddcf007488"),
+                wallet_id=test_wallet_id,  # Use the same consistent wallet ID
             )

         # Print the results
@@ -145,8 +149,13 @@ async def test_proposal_evaluation_workflow():
         print(f"Approval: {result['evaluation']['approve']}")
         print(f"Confidence: {result['evaluation']['confidence_score']}")
         print(f"Reasoning: {result['evaluation']['reasoning']}")
-        print(f"Token Usage: {result['token_usage']}")
-        print(f"Cost: ${result['token_costs']['total_cost']:.4f}")
+        print(
+            f"Total Token Usage by Model: {result.get('total_token_usage_by_model')}"
+        )
+        print(f"Total Cost by Model: {result.get('total_cost_by_model')}")
+        print(
+            f"Total Overall Cost: ${result.get('total_overall_cost', 0.0):.4f}"
+        )

         if scenario["auto_vote"]:
             print(f"Auto-voted: {result['auto_voted']}")
diff --git a/services/runner/tasks/dao_proposal_evaluation.py b/services/runner/tasks/dao_proposal_evaluation.py
index d1f1dd6b..c11c2ece 100644
--- a/services/runner/tasks/dao_proposal_evaluation.py
+++ b/services/runner/tasks/dao_proposal_evaluation.py
@@ -127,8 +127,8 @@ async def process_message(self, message: QueueMessage) -> Dict[str, Any]:
         confidence = evaluation.get("confidence_score", 0.0)
         reasoning = evaluation.get("reasoning", "No reasoning provided")
         formatted_prompt = result.get("formatted_prompt", "No prompt provided")
-        total_cost = result.get("token_costs", {}).get("total_cost", 0.0)
-        model = result.get("model_info", {}).get("name", "Unknown")
+        total_cost = result.get("total_overall_cost", 0.0)
+        model = result.get("evaluation_model_info", {}).get("name", "Unknown")

         logger.info(
             f"Proposal {proposal.id} ({dao.name}): Evaluated with result "
diff --git a/services/runner/tasks/dao_task.py b/services/runner/tasks/dao_task.py
index 3f91e4e7..f33e0fbd 100644
--- a/services/runner/tasks/dao_task.py
+++ b/services/runner/tasks/dao_task.py
@@ -13,7 +13,7 @@
     QueueMessageType,
 )
 from lib.logger import configure_logger
-from services.workflows import execute_langgraph_stream
+from services.workflows import execute_workflow_stream
 from tools.tools_factory import filter_tools_by_names, initialize_tools

 from ..base import BaseTask, JobContext, RunnerConfig, RunnerResult
@@ -181,7 +181,7 @@ async def _process_dao_message(self, message: QueueMessage) -> DAOProcessingResu
             logger.debug(f"DAO deployment parameters: {tool_input}")

             deployment_data = {}
-            async for chunk in execute_langgraph_stream(
+            async for chunk in execute_workflow_stream(
                 history=[], input_str=tool_input, tools_map=self.tools_map
             ):
                 if chunk["type"] == "result":
diff --git a/services/schedule.py b/services/schedule.py
index 0147033a..b172fcdb 100644
--- a/services/schedule.py
+++ b/services/schedule.py
@@ -10,7 +10,7 @@
 from backend.models import JobBase, JobCreate, StepCreate, Task, TaskFilter
 from lib.logger import configure_logger
 from lib.persona import generate_persona
-from services.workflows import execute_langgraph_stream
+from services.workflows import execute_workflow_stream
 from tools.tools_factory import exclude_tools_by_names, initialize_tools

 logger = configure_logger(__name__)
@@ -142,7 +142,7 @@ async def _process_job_stream(
         ["db_update_scheduled_task", "db_add_scheduled_task"], tools_map
     )

-    stream_generator = execute_langgraph_stream(
+    stream_generator = execute_workflow_stream(
         history=history,
         input_str=task.prompt,
         persona=persona,
diff --git a/services/workflows/__init__.py b/services/workflows/__init__.py
index e2e72953..183c0607 100644
--- a/services/workflows/__init__.py
+++ b/services/workflows/__init__.py
@@ -6,20 +6,21 @@
     BaseWorkflowMixin,
     ExecutionError,
     LangGraphError,
-    PlanningCapability,
+    MessageContent,
+    MessageProcessor,
     StateType,
+    StreamingCallbackHandler,
     StreamingError,
     ValidationError,
-    VectorRetrievalCapability,
 )

-# Enhanced ReAct workflow variants
-from services.workflows.preplan_react import (
-    PreplanLangGraphService,
-    PreplanReactWorkflow,
-    PreplanState,
-    execute_preplan_react_stream,
+# Remove all imports from deleted files and import from chat.py
+from services.workflows.chat import (
+    ChatService,
+    ChatWorkflow,
+    execute_chat_stream,
 )
+from services.workflows.planning_mixin import PlanningCapability

 # Special purpose workflows
 from services.workflows.proposal_evaluation import (
@@ -30,15 +31,6 @@
 # Core messaging and streaming components

 # Core ReAct workflow components
-from services.workflows.react import (
-    LangGraphService,
-    MessageContent,
-    MessageProcessor,
-    ReactState,
-    ReactWorkflow,
-    StreamingCallbackHandler,
-    execute_langgraph_stream,
-)
 from services.workflows.tweet_analysis import (
     TweetAnalysisWorkflow,
     analyze_tweet,
@@ -47,19 +39,11 @@
     TweetGeneratorWorkflow,
     generate_dao_tweet,
 )
-from services.workflows.vector_preplan_react import (
-    VectorPreplanLangGraphService,
-    VectorPreplanReactWorkflow,
-    VectorPreplanState,
-    execute_vector_preplan_stream,
-)
-from services.workflows.vector_react import (
-    VectorLangGraphService,
-    VectorReactState,
-    VectorReactWorkflow,
+from services.workflows.vector_mixin import (
+    VectorRetrievalCapability,
     add_documents_to_vectors,
-    execute_vector_langgraph_stream,
 )
+from services.workflows.web_search_mixin import WebSearchCapability

 # Workflow service and factory
 from services.workflows.workflow_service import (
@@ -76,7 +60,6 @@
     "BaseWorkflowMixin",
     "ExecutionError",
     "LangGraphError",
-    "PlanningCapability",
     "StateType",
     "StreamingError",
     "ValidationError",
@@ -96,22 +79,6 @@
"ReactState", "ReactWorkflow", "execute_langgraph_stream", - # PrePlan ReAct workflow - "PreplanLangGraphService", - "PreplanReactWorkflow", - "PreplanState", - "execute_preplan_react_stream", - # Vector ReAct workflow - "VectorLangGraphService", - "VectorReactState", - "VectorReactWorkflow", - "add_documents_to_vectors", - "execute_vector_langgraph_stream", - # Vector PrePlan ReAct workflow - "VectorPreplanLangGraphService", - "VectorPreplanReactWorkflow", - "VectorPreplanState", - "execute_vector_preplan_stream", # Special purpose workflows "ProposalEvaluationWorkflow", "TweetAnalysisWorkflow", @@ -120,4 +87,12 @@ "evaluate_and_vote_on_proposal", "evaluate_proposal_only", "generate_dao_tweet", + # Chat workflow + "ChatService", + "ChatWorkflow", + "execute_chat_stream", + # Mixins + "PlanningCapability", + "WebSearchCapability", + "add_documents_to_vectors", ] diff --git a/services/workflows/base.py b/services/workflows/base.py index 2259335e..1689b442 100644 --- a/services/workflows/base.py +++ b/services/workflows/base.py @@ -1,16 +1,19 @@ """Base workflow functionality and shared components for all workflow types.""" +import asyncio +import datetime import json +import uuid from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Any, Dict, Generic, List, Optional, TypeVar, Union from langchain.prompts import PromptTemplate -from langchain.schema import Document -from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_core.callbacks import BaseCallbackHandler +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI from langgraph.graph import Graph, StateGraph -from openai import OpenAI -from backend.factory import backend from lib.logger import configure_logger logger = configure_logger(__name__) @@ -242,333 +245,361 @@ def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: pass -class PlanningCapability(BaseWorkflowMixin): - """Mixin that adds planning capabilities to a workflow.""" +@dataclass +class MessageContent: + """Data class for message content""" - async def create_plan(self, query: str, **kwargs) -> str: - """Create a plan based on the user's query. + role: str + content: str + tool_calls: Optional[List[Dict]] = None - Args: - query: The user's query to plan for - **kwargs: Additional arguments (callback_handler, etc.) - - Returns: - The generated plan - """ - raise NotImplementedError("PlanningCapability must implement create_plan") - - def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: - """Integrate planning capability with a graph. - - This adds the planning capability to the graph by modifying - the entry point to first create a plan. 
- - Args: - graph: The graph to integrate with - **kwargs: Additional arguments specific to planning - """ - # Implementation depends on specific graph structure - raise NotImplementedError( - "PlanningCapability must implement integrate_with_graph" + @classmethod + def from_dict(cls, data: Dict) -> "MessageContent": + """Create MessageContent from dictionary""" + return cls( + role=data.get("role", ""), + content=data.get("content", ""), + tool_calls=data.get("tool_calls"), ) -class VectorRetrievalCapability(BaseWorkflowMixin): - """Mixin that adds vector retrieval capabilities to a workflow.""" +class MessageProcessor: + """Processor for messages""" - def __init__(self, *args, **kwargs): - """Initialize the vector retrieval capability.""" - # Initialize parent class if it exists - super().__init__(*args, **kwargs) if hasattr(super(), "__init__") else None - # Initialize our attributes - self._init_vector_retrieval() - - def _init_vector_retrieval(self) -> None: - """Initialize vector retrieval attributes if not already initialized.""" - if not hasattr(self, "collection_names"): - self.collection_names = ["knowledge_collection", "dao_collection"] - if not hasattr(self, "embeddings"): - self.embeddings = OpenAIEmbeddings() - if not hasattr(self, "vector_results_cache"): - self.vector_results_cache = {} - - async def retrieve_from_vector_store(self, query: str, **kwargs) -> List[Document]: - """Retrieve relevant documents from multiple vector stores. + @staticmethod + def extract_filtered_content(history: List[Dict]) -> List[Dict]: + """Extract and filter content from message history.""" + logger.debug( + f"Starting content extraction from history with {len(history)} messages" + ) + filtered_content = [] - Args: - query: The query to search for - **kwargs: Additional arguments (collection_name, embeddings, etc.) + for message in history: + logger.debug(f"Processing message type: {message.get('role')}") + if message.get("role") in ["user", "assistant"]: + filtered_content.append(MessageContent.from_dict(message).__dict__) - Returns: - List of retrieved documents - """ - try: - # Ensure initialization - self._init_vector_retrieval() - - # Check cache first - if query in self.vector_results_cache: - logger.debug(f"Using cached vector results for query: {query}") - return self.vector_results_cache[query] - - all_documents = [] - limit_per_collection = kwargs.get("limit", 4) - logger.debug( - f"Searching vector store: query={query} | limit_per_collection={limit_per_collection}" - ) - - # Query each collection and gather results - for collection_name in self.collection_names: - try: - # Query vectors using the backend - vector_results = await backend.query_vectors( - collection_name=collection_name, - query_text=query, - limit=limit_per_collection, - embeddings=self.embeddings, + logger.debug( + f"Finished filtering content, extracted {len(filtered_content)} messages" + ) + return filtered_content + + @staticmethod + def convert_to_langchain_messages( + filtered_content: List[Dict], + current_input: str, + persona: Optional[str] = None, + ) -> List[Union[SystemMessage, HumanMessage, AIMessage]]: + """Convert filtered content to LangChain message format.""" + messages = [] + + # Add decisiveness instruction + decisiveness_instruction = "Be decisive and action-oriented. When the user requests something, execute it immediately without asking for confirmation." 
+ + if persona: + logger.debug("Adding persona message with decisiveness instruction") + # Add the decisiveness instruction to the persona + enhanced_persona = f"{persona}\n\n{decisiveness_instruction}" + messages.append(SystemMessage(content=enhanced_persona)) + else: + # If no persona, add the decisiveness instruction as a system message + logger.debug("Adding decisiveness instruction as system message") + messages.append(SystemMessage(content=decisiveness_instruction)) + + for msg in filtered_content: + if msg["role"] == "user": + messages.append(HumanMessage(content=msg["content"])) + else: + content = msg.get("content") or "" + if msg.get("tool_calls"): + messages.append( + AIMessage(content=content, tool_calls=msg["tool_calls"]) ) + else: + messages.append(AIMessage(content=content)) - # Convert to LangChain Documents and add collection source - documents = [ - Document( - page_content=doc.get("page_content", ""), - metadata={ - **doc.get("metadata", {}), - "collection_source": collection_name, - }, - ) - for doc in vector_results - ] - - all_documents.extend(documents) - logger.debug( - f"Retrieved {len(documents)} documents from collection {collection_name}" - ) - except Exception as e: - logger.error( - f"Failed to retrieve from collection {collection_name}: {str(e)}", - exc_info=True, - ) - continue # Continue with other collections if one fails + messages.append(HumanMessage(content=current_input)) + logger.debug(f"Prepared message chain with {len(messages)} total messages") + return messages - logger.debug( - f"Retrieved total of {len(all_documents)} documents from all collections" - ) - # Cache the results - self.vector_results_cache[query] = all_documents +class StreamingCallbackHandler(BaseCallbackHandler): + """Handle callbacks from LangChain and stream results to a queue.""" - return all_documents + def __init__( + self, + queue: asyncio.Queue, + on_llm_new_token: Optional[callable] = None, + on_llm_end: Optional[callable] = None, + ): + """Initialize the callback handler with a queue.""" + self.queue = queue + self.tool_states = {} # Store tool states by invocation ID + self.tool_inputs = {} # Store tool inputs by invocation ID + self.active_tools = {} # Track active tools by name for fallback + self.custom_on_llm_new_token = on_llm_new_token + self.custom_on_llm_end = on_llm_end + # Track the current execution phase + self.current_phase = "processing" # Default phase is processing + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + """Get the current event loop or create a new one if necessary.""" + try: + loop = asyncio.get_running_loop() + return loop + except RuntimeError: + logger.debug("No running event loop found. Creating a new one.") + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + async def _async_put_to_queue(self, item: Dict) -> None: + """Put an item in the queue asynchronously.""" + try: + await self.queue.put(item) except Exception as e: - logger.error(f"Vector store retrieval failed: {str(e)}", exc_info=True) - return [] + logger.error(f"Failed to put item in queue: {str(e)}") + raise StreamingError(f"Queue operation failed: {str(e)}") - def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: - """Integrate vector retrieval capability with a graph. 
+ def _put_to_queue(self, item: Dict) -> None: + """Put an item in the queue, handling event loop considerations.""" + try: + loop = self._ensure_loop() + if loop.is_running(): + future = asyncio.run_coroutine_threadsafe( + self._async_put_to_queue(item), loop + ) + future.result() + else: + loop.run_until_complete(self._async_put_to_queue(item)) + except Exception as e: + logger.error(f"Failed to put item in queue: {str(e)}") + raise StreamingError(f"Queue operation failed: {str(e)}") - This adds the vector retrieval capability to the graph by adding a node - that can perform vector searches when needed. + def _get_tool_info( + self, invocation_id: Optional[str], tool_name: Optional[str] = None + ) -> Optional[tuple]: + """Get tool information using either invocation_id or tool_name. - Args: - graph: The graph to integrate with - **kwargs: Additional arguments specific to vector retrieval including: - - collection_names: List of collection names to search - - limit_per_collection: Number of results per collection + Returns: + Optional[tuple]: (tool_name, tool_input, invocation_id) if found, None otherwise """ - # Add vector search node - graph.add_node("vector_search", self.retrieve_from_vector_store) - - # Add result processing node if needed - if "process_vector_results" not in graph.nodes: - graph.add_node("process_vector_results", self._process_vector_results) - graph.add_edge("vector_search", "process_vector_results") + if invocation_id and invocation_id in self.tool_states: + return ( + self.tool_states[invocation_id], + self.tool_inputs.get(invocation_id, ""), + invocation_id, + ) + elif tool_name and tool_name in self.active_tools: + active_info = self.active_tools[tool_name] + return (tool_name, active_info["input"], active_info["invocation_id"]) + return None - async def _process_vector_results( - self, vector_results: List[Document], **kwargs - ) -> Dict[str, Any]: - """Process vector search results. + async def process_step( + self, content: str, role: str = "assistant", thought: Optional[str] = None + ) -> None: + """Process a planning step and queue it with the planning status. 
Args: - vector_results: Results from vector search - **kwargs: Additional processing arguments - - Returns: - Processed results with metadata + content: The planning step content + role: The role associated with the step (usually assistant) + thought: Optional thought process notes """ - return { - "results": vector_results, - "metadata": { - "num_vector_results": len(vector_results), - "collection_sources": list( - set( - doc.metadata.get("collection_source", "unknown") - for doc in vector_results - ) - ), - }, - } - - -class WebSearchCapability(BaseWorkflowMixin): - """Mixin that adds web search capabilities to a workflow using OpenAI Responses API.""" - - def __init__(self, *args, **kwargs): - """Initialize the web search capability.""" - # Initialize parent class if it exists - super().__init__(*args, **kwargs) if hasattr(super(), "__init__") else None - # Initialize our attributes - self._init_web_search() + try: + # Create step message with explicit planning status + current_time = datetime.datetime.now().isoformat() + step_message = { + "type": "step", + "status": "planning", # Explicitly mark as planning phase + "content": content, + "role": role, + "thought": thought + or "Planning Phase", # Default to Planning Phase if thought is not provided + "created_at": current_time, + "planning_only": True, # Mark this content as planning-only to prevent duplication + } - def _init_web_search(self) -> None: - """Initialize web search attributes if not already initialized.""" - if not hasattr(self, "search_results_cache"): - self.search_results_cache = {} - if not hasattr(self, "client"): - self.client = OpenAI() + logger.debug(f"Queuing planning step message with length: {len(content)}") + await self._async_put_to_queue(step_message) + except Exception as e: + logger.error(f"Failed to process planning step: {str(e)}") + raise StreamingError(f"Planning step processing failed: {str(e)}") + + def on_tool_start(self, serialized: Dict, input_str: str, **kwargs) -> None: + """Run when tool starts running.""" + tool_name = serialized.get("name") + if not tool_name: + logger.warning("Tool start called without tool name") + return + + invocation_id = kwargs.get("invocation_id", str(uuid.uuid4())) + + # Store in both tracking systems + self.tool_states[invocation_id] = tool_name + self.tool_inputs[invocation_id] = input_str + self.active_tools[tool_name] = { + "invocation_id": invocation_id, + "input": input_str, + "start_time": datetime.datetime.now(), + } - async def search_web(self, query: str, **kwargs) -> List[Dict[str, Any]]: - """Search the web using OpenAI Responses API. + logger.info( + f"Tool started: {tool_name} (ID: {invocation_id}) with input: {input_str[:100]}..." 
+ ) - Args: - query: The search query - **kwargs: Additional search parameters like user_location and search_context_size + def on_tool_end(self, output: str, **kwargs) -> None: + """Run when tool ends running.""" + invocation_id = kwargs.get("invocation_id") + tool_name = kwargs.get("name") # Try to get tool name from kwargs - Returns: - List of search results with content and metadata - """ - try: - # Ensure initialization - self._init_web_search() - - # Check cache first - if query in self.search_results_cache: - logger.info(f"Using cached results for query: {query}") - return self.search_results_cache[query] - - # Configure web search tool - tool_config = { - "type": "web_search_preview", - "search_context_size": kwargs.get("search_context_size", "medium"), - } + # Try to get tool info from either source + tool_info = self._get_tool_info(invocation_id, tool_name) - # Add user location if provided - if "user_location" in kwargs: - tool_config["user_location"] = kwargs["user_location"] + if tool_info: + tool_name, tool_input, used_invocation_id = tool_info + if hasattr(output, "content"): + output = output.content - # Make the API call - response = self.client.responses.create( - model="gpt-4.1", tools=[tool_config], input=query + self._put_to_queue( + { + "type": "tool", + "tool": tool_name, + "input": tool_input, + "output": str(output), + "status": "processing", # Use "processing" status for tool end + "created_at": datetime.datetime.now().isoformat(), + } + ) + logger.info( + f"Tool {tool_name} (ID: {used_invocation_id}) completed with output length: {len(str(output))}" ) - logger.debug(f"Web search response: {response}") - # Process the response into our document format - documents = [] - - # Access the output text directly - if hasattr(response, "output_text"): - text_content = response.output_text - source_urls = [] - - # Try to extract citations if available - if hasattr(response, "citations"): - source_urls = [ - { - "url": citation.url, - "title": getattr(citation, "title", ""), - "start_index": getattr(citation, "start_index", 0), - "end_index": getattr(citation, "end_index", 0), - } - for citation in response.citations - if hasattr(citation, "url") - ] - - # Ensure we always have at least one URL entry - if not source_urls: - source_urls = [ - { - "url": "No source URL available", - "title": "Generated Response", - "start_index": 0, - "end_index": len(text_content), - } - ] - - # Create document with content - doc = { - "page_content": text_content, - "metadata": { - "type": "web_search_result", - "source_urls": source_urls, - "query": query, - "timestamp": None, - }, - } - documents.append(doc) + # Clean up tracking + if used_invocation_id in self.tool_states: + del self.tool_states[used_invocation_id] + del self.tool_inputs[used_invocation_id] + if tool_name in self.active_tools: + del self.active_tools[tool_name] + else: + logger.warning( + f"Tool end called with unknown invocation ID: {invocation_id} and tool name: {tool_name}" + ) - # Cache the results - self.search_results_cache[query] = documents + def on_tool_error(self, error: Exception, **kwargs) -> None: + """Run when tool errors.""" + invocation_id = kwargs.get("invocation_id") + tool_name = kwargs.get("name") # Try to get tool name from kwargs - logger.info(f"Web search completed with {len(documents)} results") - return documents + # Try to get tool info from either source + tool_info = self._get_tool_info(invocation_id, tool_name) - except Exception as e: - logger.error(f"Web search failed: {str(e)}") - # Return 
a list with one empty result to prevent downstream errors - return [ + if tool_info: + tool_name, tool_input, used_invocation_id = tool_info + self._put_to_queue( { - "page_content": "Web search failed to return results.", - "metadata": { - "type": "web_search_result", - "source_urls": [ - { - "url": "Error occurred during web search", - "title": "Error", - "start_index": 0, - "end_index": 0, - } - ], - "query": query, - "timestamp": None, - }, + "type": "tool", + "tool": tool_name, + "input": tool_input, + "output": f"Error: {str(error)}", + "status": "error", + "created_at": datetime.datetime.now().isoformat(), } - ] - - def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: - """Integrate web search capability with a graph. + ) + logger.error( + f"Tool {tool_name} (ID: {used_invocation_id}) failed with error: {str(error)}", + exc_info=True, + ) - This adds the web search capability to the graph by adding a node - that can perform web searches when needed. + # Clean up tracking + if used_invocation_id in self.tool_states: + del self.tool_states[used_invocation_id] + del self.tool_inputs[used_invocation_id] + if tool_name in self.active_tools: + del self.active_tools[tool_name] + else: + logger.warning( + f"Tool error called with unknown invocation ID: {invocation_id} and tool name: {tool_name}" + ) - Args: - graph: The graph to integrate with - **kwargs: Additional arguments specific to web search including: - - search_context_size: "low", "medium", or "high" - - user_location: dict with type, country, city, region - """ - # Add web search node - graph.add_node("web_search", self.search_web) + def on_llm_start(self, *args, **kwargs) -> None: + """Run when LLM starts running.""" + logger.info("LLM processing started") - # Add result processing node if needed - if "process_results" not in graph.nodes: - graph.add_node("process_results", self._process_results) - graph.add_edge("web_search", "process_results") + def on_llm_new_token(self, token: str, **kwargs) -> None: + """Run on new token.""" + # Check if we have planning_only in the kwargs + planning_only = kwargs.get("planning_only", False) - async def _process_results( - self, web_results: List[Dict[str, Any]], **kwargs - ) -> Dict[str, Any]: - """Process web search results. 
+ # Handle custom token processing if provided + if self.custom_on_llm_new_token: + try: + # Check if it's a coroutine function and handle accordingly + if asyncio.iscoroutinefunction(self.custom_on_llm_new_token): + # For coroutines, we need to schedule it to run without awaiting + loop = self._ensure_loop() + # Create the coroutine object without calling it + coro = self.custom_on_llm_new_token(token, **kwargs) + # Schedule it to run in the event loop + asyncio.run_coroutine_threadsafe(coro, loop) + else: + # Regular function call + self.custom_on_llm_new_token(token, **kwargs) + except Exception as e: + logger.error(f"Error in custom token handler: {str(e)}", exc_info=True) + + # Log token information with phase information + phase = "planning" if planning_only else "processing" + logger.debug(f"Received new token (length: {len(token)}, phase: {phase})") + + def on_llm_end(self, response, **kwargs) -> None: + """Run when LLM ends running.""" + logger.info("LLM processing completed") + + # Queue an end message with complete status + try: + self._put_to_queue( + { + "type": "token", + "status": "complete", + "content": "", + "created_at": datetime.datetime.now().isoformat(), + } + ) + except Exception as e: + logger.error(f"Failed to queue completion message: {str(e)}") - Args: - web_results: Results from web search - **kwargs: Additional processing arguments + # Handle custom end processing if provided + if self.custom_on_llm_end: + try: + # Check if it's a coroutine function and handle accordingly + if asyncio.iscoroutinefunction(self.custom_on_llm_end): + # For coroutines, we need to schedule it to run without awaiting + loop = self._ensure_loop() + # Create the coroutine object without calling it + coro = self.custom_on_llm_end(response, **kwargs) + # Schedule it to run in the event loop + asyncio.run_coroutine_threadsafe(coro, loop) + else: + # Regular function call + self.custom_on_llm_end(response, **kwargs) + except Exception as e: + logger.error(f"Error in custom end handler: {str(e)}", exc_info=True) + + def on_llm_error(self, error: Exception, **kwargs) -> None: + """Run when LLM errors.""" + logger.error(f"LLM error occurred: {str(error)}", exc_info=True) + + # Send error status + try: + self._put_to_queue( + { + "type": "token", + "status": "error", + "content": f"Error: {str(error)}", + "created_at": datetime.datetime.now().isoformat(), + } + ) + except Exception: + pass # Don't raise another error if this fails - Returns: - Processed results with metadata - """ - return { - "results": web_results, - "metadata": { - "num_web_results": len(web_results), - "source_types": ["web_search"], - }, - } + raise ExecutionError("LLM processing failed", {"error": str(error)}) diff --git a/services/workflows/vector_preplan_react.py b/services/workflows/chat.py similarity index 87% rename from services/workflows/vector_preplan_react.py rename to services/workflows/chat.py index 20067308..22e6309f 100644 --- a/services/workflows/vector_preplan_react.py +++ b/services/workflows/chat.py @@ -29,31 +29,31 @@ from services.workflows.base import ( BaseWorkflow, ExecutionError, - PlanningCapability, + MessageProcessor, + StreamingCallbackHandler, +) +from services.workflows.planning_mixin import PlanningCapability +from services.workflows.vector_mixin import ( VectorRetrievalCapability, - WebSearchCapability, ) -from services.workflows.react import StreamingCallbackHandler - -# Remove this import to avoid circular dependencies -# from services.workflows.workflow_service import 
BaseWorkflowService, WorkflowBuilder +from services.workflows.web_search_mixin import WebSearchCapability logger = configure_logger(__name__) -class VectorPreplanState(TypedDict): - """State for the Vector PrePlan ReAct workflow, combining both capabilities.""" +class ChatState(TypedDict): + """State for the Chat workflow, combining all capabilities.""" messages: Annotated[list, add_messages] vector_results: Optional[List[Document]] - web_search_results: Optional[List[Document]] # Add web search results + web_search_results: Optional[List[Document]] # Web search results plan: Optional[str] -class VectorPreplanReactWorkflow( - BaseWorkflow[VectorPreplanState], - VectorRetrievalCapability, +class ChatWorkflow( + BaseWorkflow[ChatState], PlanningCapability, + VectorRetrievalCapability, WebSearchCapability, ): """Workflow that combines vector retrieval and planning capabilities. @@ -110,6 +110,18 @@ def __init__( self.persona = None self.tool_descriptions = None + # Initialize mixins + PlanningCapability.__init__( + self, + callback_handler=callback_handler, + planning_llm=self.planning_llm, + persona=self.persona, + tool_names=self.tool_names, + tool_descriptions=self.tool_descriptions, + ) + VectorRetrievalCapability.__init__(self) + WebSearchCapability.__init__(self) + def _create_prompt(self) -> None: """Not used in Vector PrePlan ReAct workflow.""" pass @@ -337,14 +349,14 @@ def _create_graph(self) -> StateGraph: tool_node = ToolNode(self.tools) logger.debug(f"Created tool node with {len(self.tools)} tools") - def should_continue(state: VectorPreplanState) -> str: + def should_continue(state: ChatState) -> str: messages = state["messages"] last_message = messages[-1] result = "tools" if last_message.tool_calls else END logger.debug(f"Continue decision: {result}") return result - async def retrieve_context(state: VectorPreplanState) -> Dict: + async def retrieve_context(state: ChatState) -> Dict: """Retrieve context from both vector store and web search.""" messages = state["messages"] last_user_message = None @@ -373,7 +385,7 @@ async def retrieve_context(state: VectorPreplanState) -> Dict: return {"vector_results": vector_results, "web_search_results": web_results} - def call_model_with_context_and_plan(state: VectorPreplanState) -> Dict: + def call_model_with_context_and_plan(state: ChatState) -> Dict: """Call model with context, plan, and web search results.""" messages = state["messages"] vector_results = state.get("vector_results", []) @@ -443,7 +455,7 @@ def call_model_with_context_and_plan(state: VectorPreplanState) -> Dict: response = self.llm.invoke(messages) return {"messages": [response]} - workflow = StateGraph(VectorPreplanState) + workflow = StateGraph(ChatState) # Add nodes workflow.add_node("context_retrieval", retrieve_context) @@ -460,33 +472,27 @@ def call_model_with_context_and_plan(state: VectorPreplanState) -> Dict: return workflow -class VectorPreplanLangGraphService: - """Service for executing Vector PrePlan React LangGraph operations""" +class ChatService: + """Service for executing Chat LangGraph operations.""" def __init__( self, collection_names: Union[str, List[str]], embeddings: Optional[Embeddings] = None, ): - # Import here to avoid circular imports - from services.workflows.react import MessageProcessor self.collection_names = collection_names self.embeddings = embeddings or OpenAIEmbeddings() self.message_processor = MessageProcessor() def setup_callback_handler(self, queue, loop): - # Import here to avoid circular dependencies from 
services.workflows.workflow_service import BaseWorkflowService - # Use the static method instead of instantiating BaseWorkflowService return BaseWorkflowService.create_callback_handler(queue, loop) async def stream_task_results(self, task, queue): - # Import here to avoid circular dependencies from services.workflows.workflow_service import BaseWorkflowService - # Use the static method instead of instantiating BaseWorkflowService async for chunk in BaseWorkflowService.stream_results_from_task( task=task, callback_queue=queue, logger_name=self.__class__.__name__ ): @@ -500,32 +506,14 @@ async def _execute_stream_impl( tools_map: Optional[Dict] = None, **kwargs, ) -> AsyncGenerator[Dict, None]: - """Execute a Vector PrePlan React stream implementation. - - Args: - messages: Processed messages - input_str: Current user input - persona: Optional persona to use - tools_map: Optional tools to use - **kwargs: Additional arguments - - Returns: - Async generator of result chunks - """ try: - # Import here to avoid circular dependencies from services.workflows.workflow_service import WorkflowBuilder - # Setup queue and callbacks callback_queue = asyncio.Queue() loop = asyncio.get_running_loop() - - # Setup callback handler callback_handler = self.setup_callback_handler(callback_queue, loop) - - # Create workflow using builder pattern workflow = ( - WorkflowBuilder(VectorPreplanReactWorkflow) + WorkflowBuilder(ChatWorkflow) .with_callback_handler(callback_handler) .with_tools(list(tools_map.values()) if tools_map else []) .build( @@ -533,17 +521,11 @@ async def _execute_stream_impl( embeddings=self.embeddings, ) ) - - # Store persona and tool information for planning if persona: - # Append decisiveness guidance to the persona decisive_guidance = "\n\nBe decisive and take action without asking for confirmation. When the user requests something, proceed directly with executing it rather than asking if they want you to do it." workflow.persona = persona + decisive_guidance - - # Store available tool names for planning if tools_map: workflow.tool_names = list(tools_map.keys()) - # Add tool descriptions to planning prompt tool_descriptions = "\n\nTOOL DESCRIPTIONS:\n" for name, tool in tools_map.items(): description = getattr( @@ -551,17 +533,12 @@ async def _execute_stream_impl( ) tool_descriptions += f"- {name}: {description}\n" workflow.tool_descriptions = tool_descriptions - - # First retrieve relevant documents from vector store logger.info( f"Retrieving documents from vector store for query: {input_str[:50]}..." 
) documents = await workflow.retrieve_from_vector_store(query=input_str) logger.info(f"Retrieved {len(documents)} documents from vector store") - - # Create plan with vector context try: - # The thought notes will be streamed through callbacks logger.info("Creating plan with vector context...") plan = await workflow.create_plan(input_str, context_docs=documents) logger.info(f"Plan created successfully with {len(plan)} characters") @@ -571,15 +548,10 @@ async def _execute_stream_impl( "type": "token", "content": "Proceeding directly to answer...\n\n", } - # No plan will be provided, letting the LLM handle the task naturally plan = None - - # Create graph and compile graph = workflow._create_graph() runnable = graph.compile() logger.info("Graph compiled successfully") - - # Execute workflow with callbacks config config = {"callbacks": [callback_handler]} task = asyncio.create_task( runnable.ainvoke( @@ -587,18 +559,12 @@ async def _execute_stream_impl( config=config, ) ) - - # Stream results async for chunk in self.stream_task_results(task, callback_queue): yield chunk - except Exception as e: - logger.error( - f"Failed to execute Vector PrePlan stream: {str(e)}", exc_info=True - ) - raise ExecutionError(f"Vector PrePlan stream execution failed: {str(e)}") + logger.error(f"Failed to execute Chat stream: {str(e)}", exc_info=True) + raise ExecutionError(f"Chat stream execution failed: {str(e)}") - # Add execute_stream method to maintain the same interface as BaseWorkflowService async def execute_stream( self, history: List[Dict], @@ -607,17 +573,10 @@ async def execute_stream( tools_map: Optional[Dict] = None, **kwargs, ) -> AsyncGenerator[Dict, None]: - """Execute a workflow stream. - - This processes the history and delegates to _execute_stream_impl. - """ - # Process messages filtered_content = self.message_processor.extract_filtered_content(history) messages = self.message_processor.convert_to_langchain_messages( filtered_content, input_str, persona ) - - # Call the implementation async for chunk in self._execute_stream_impl( messages=messages, input_str=input_str, @@ -629,7 +588,7 @@ async def execute_stream( # Facade function -async def execute_vector_preplan_stream( +async def execute_chat_stream( collection_names: Union[str, List[str]], history: List[Dict], input_str: str, @@ -637,30 +596,17 @@ async def execute_vector_preplan_stream( tools_map: Optional[Dict] = None, embeddings: Optional[Embeddings] = None, ) -> AsyncGenerator[Dict, None]: - """Execute a Vector PrePlan ReAct stream. + """Execute a Chat stream. This workflow combines vector retrieval and planning: 1. Retrieves relevant context from multiple vector stores 2. Creates a plan based on the user's query and retrieved context 3. 
Executes the ReAct workflow with both context and plan - - Args: - collection_names: Name(s) of the vector collections to use - history: Conversation history - input_str: Current user input - persona: Optional persona to use - tools_map: Optional tools to make available - embeddings: Optional embeddings model - - Returns: - Async generator of result chunks """ - # Initialize service and run stream embeddings = embeddings or OpenAIEmbeddings() - service = VectorPreplanLangGraphService( + service = ChatService( collection_names=collection_names, embeddings=embeddings, ) - async for chunk in service.execute_stream(history, input_str, persona, tools_map): yield chunk diff --git a/services/workflows/planning_mixin.py b/services/workflows/planning_mixin.py new file mode 100644 index 00000000..e97c71f3 --- /dev/null +++ b/services/workflows/planning_mixin.py @@ -0,0 +1,178 @@ +"""Planning mixin for workflows, providing vector-aware planning capabilities.""" + +import asyncio +from typing import Any, Dict, List, Optional, Tuple + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI + +from lib.logger import configure_logger +from services.workflows.base import BaseWorkflowMixin +from services.workflows.chat import StreamingCallbackHandler + +logger = configure_logger(__name__) + + +class PlanningCapability(BaseWorkflowMixin): + """Mixin that adds vector-aware planning capabilities to a workflow. + + This mixin generates a plan based on the user's query, retrieved vector context, + available tools, and persona. It streams planning tokens using a callback handler. + """ + + def __init__( + self, + callback_handler: StreamingCallbackHandler, + planning_llm: ChatOpenAI, + persona: Optional[str] = None, + tool_names: Optional[List[str]] = None, + tool_descriptions: Optional[str] = None, + **kwargs, + ): + """Initialize the planning capability. + + Args: + callback_handler: Handler for streaming planning tokens + planning_llm: LLM instance for planning + persona: Optional persona string + tool_names: Optional list of tool names + tool_descriptions: Optional tool descriptions string + **kwargs: Additional arguments + """ + super().__init__(**kwargs) if hasattr(super(), "__init__") else None + self.callback_handler = callback_handler + self.planning_llm = planning_llm + self.persona = persona + self.tool_names = tool_names or [] + self.tool_descriptions = tool_descriptions + + async def create_plan( + self, + query: str, + context_docs: Optional[List[Any]] = None, + **kwargs, + ) -> Tuple[str, Dict[str, Any]]: + """Create a plan based on the user's query and vector retrieval results. + + Args: + query: The user's query + context_docs: Optional retrieved context documents + **kwargs: Additional arguments + + Returns: + Tuple containing the generated plan (str) and token usage (dict) + """ + planning_prompt = f""" + You are an AI assistant planning a decisive response to the user's query. + + Write a few short sentences as if you're taking notes in a notebook about: + - What the user is asking for + - What information or tools you'll use to complete the task + - The exact actions you'll take to fulfill the request + + AIBTC DAO Context Information: + You are an AI governance agent integrated with an AIBTC DAO. Your role is to interact with the DAO's smart contracts + on behalf of token holders, either by assisting human users or by acting autonomously within the DAO's rules. 
The DAO + is governed entirely by its token holders through proposals – members submit proposals, vote on them, and if a proposal passes, + it is executed on-chain. Always maintain the integrity of the DAO's decentralized process: never bypass on-chain governance, + and ensure all actions strictly follow the DAO's smart contract rules and parameters. + + Your responsibilities include: + 1. Helping users create and submit proposals to the DAO + 2. Guiding users through the voting process + 3. Explaining how DAO contract interactions work + 4. Preventing invalid actions and detecting potential exploits + 5. In autonomous mode, monitoring DAO state, proposing actions, and voting according to governance rules + + When interacting with users about the DAO, always: + - Retrieve contract addresses automatically instead of asking users + - Validate transactions before submission + - Present clear summaries of proposed actions + - Verify eligibility and check voting power + - Format transactions precisely according to blockchain requirements + - Provide confirmation and feedback after actions + + DAO Tools Usage: + For ANY DAO-related request, use the appropriate DAO tools to access real-time information: + - Use dao_list tool to retrieve all DAOs, their tokens, and extensions + - Use dao_search tool to find specific DAOs by name, description, token name, symbol, or contract ID + - Do NOT hardcode DAO information or assumptions about contract addresses + - Always query for the latest DAO data through the tools rather than relying on static information + - When analyzing user requests, determine if they're asking about a specific DAO or need a list of DAOs + - After retrieving DAO information, use it to accurately guide users through governance processes + + Examples of effective DAO tool usage: + 1. If user asks about voting on a proposal: First use dao_search to find the specific DAO, then guide them with the correct contract details + 2. If user asks to list available DAOs: Use dao_list to retrieve current DAOs and present them clearly + 3. If user wants to create a proposal: Use dao_search to get the DAO details first, then assist with the proposal creation using the current contract addresses + + User Query: {query} + """ + if context_docs: + context_str = "\n\n".join( + [getattr(doc, "page_content", str(doc)) for doc in context_docs] + ) + planning_prompt += f"\n\nHere is additional context that may be helpful:\n\n{context_str}\n\nUse this context to inform your plan." 
+ if self.tool_names: + tool_info = "\n\nTools available to you:\n" + for tool_name in self.tool_names: + tool_info += f"- {tool_name}\n" + planning_prompt += tool_info + if self.tool_descriptions: + planning_prompt += self.tool_descriptions + planning_messages = [] + if self.persona: + planning_messages.append(SystemMessage(content=self.persona)) + planning_messages.append(HumanMessage(content=planning_prompt)) + try: + logger.info( + "Creating thought process notes for user query with vector context" + ) + original_new_token = self.callback_handler.custom_on_llm_new_token + + async def planning_token_wrapper(token, **kwargs): + if asyncio.iscoroutinefunction(original_new_token): + await original_new_token(token, planning_only=True, **kwargs) + else: + loop = asyncio.get_running_loop() + asyncio.run_coroutine_threadsafe( + self.callback_handler.queue.put( + { + "type": "token", + "content": token, + "status": "planning", + "planning_only": True, + } + ), + loop, + ) + + self.callback_handler.custom_on_llm_new_token = planning_token_wrapper + task = asyncio.create_task(self.planning_llm.ainvoke(planning_messages)) + response = await task + plan = response.content + token_usage = response.usage_metadata or { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + self.callback_handler.custom_on_llm_new_token = original_new_token + logger.info( + "Thought process notes created successfully with vector context" + ) + logger.debug(f"Notes content length: {len(plan)}") + logger.debug(f"Planning token usage: {token_usage}") + await self.callback_handler.process_step( + content=plan, role="assistant", thought="Planning Phase with Context" + ) + return plan, token_usage + except Exception as e: + if hasattr(self.callback_handler, "custom_on_llm_new_token"): + self.callback_handler.custom_on_llm_new_token = original_new_token + logger.error(f"Failed to create plan: {str(e)}", exc_info=True) + # Return empty plan and zero usage on error + return "Failed to create plan.", { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } diff --git a/services/workflows/preplan_react.py b/services/workflows/preplan_react.py deleted file mode 100644 index 8bd1f3e1..00000000 --- a/services/workflows/preplan_react.py +++ /dev/null @@ -1,481 +0,0 @@ -"""PrePlan ReAct workflow functionality. - -This workflow first creates a plan based on the user's query, then executes -the ReAct workflow to complete the task according to the plan. -""" - -import asyncio -from typing import ( - Annotated, - Any, - AsyncGenerator, - Dict, - List, - Optional, - TypedDict, - Union, -) - -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from langchain_openai import ChatOpenAI -from langgraph.graph import END, START, StateGraph -from langgraph.graph.message import add_messages -from langgraph.prebuilt import ToolNode - -from lib.logger import configure_logger -from services.workflows.base import BaseWorkflow, ExecutionError, PlanningCapability -from services.workflows.react import MessageProcessor, StreamingCallbackHandler - -# Remove this import to avoid circular dependencies -# from services.workflows.workflow_service import BaseWorkflowService, WorkflowBuilder - -logger = configure_logger(__name__) - - -class PreplanState(TypedDict): - """State for the PrePlan ReAct workflow.""" - - messages: Annotated[list, add_messages] - plan: Optional[str] - - -class PreplanReactWorkflow(BaseWorkflow[PreplanState], PlanningCapability): - """PrePlan ReAct workflow implementation. 
- - This workflow first creates a plan based on the user's query, - then executes the ReAct workflow to complete the task according to the plan. - """ - - def __init__( - self, - callback_handler: StreamingCallbackHandler, - tools: List[Any], - **kwargs, - ): - super().__init__(**kwargs) - self.callback_handler = callback_handler - self.tools = tools - self.required_fields = ["messages"] - # Set decisive behavior flag - self.decisive_behavior = True - - # Create a new LLM instance with the callback handler - self.llm = self.create_llm_with_callbacks([callback_handler]).bind_tools(tools) - - # Create a separate LLM for planning with streaming enabled - self.planning_llm = ChatOpenAI( - model="o4-mini", - streaming=True, # Enable streaming for the planning LLM - callbacks=[callback_handler], - ) - - # Store tool information for planning - self.tool_names = [] - if tools: - self.tool_names = [ - tool.name if hasattr(tool, "name") else str(tool) for tool in tools - ] - - # Additional attributes for planning - self.persona = None - self.tool_descriptions = None - - def _create_prompt(self) -> None: - """Not used in PrePlan ReAct workflow.""" - pass - - async def create_plan(self, query: str) -> str: - """Create a simple thought process plan based on the user's query.""" - # Create a more decisive planning prompt - planning_prompt = f""" - You are an AI assistant planning a decisive response to the user's query. - - Write a few short sentences as if you're taking notes in a notebook about: - - What the user is asking for - - What information or tools you'll use to complete the task - - The exact actions you'll take to fulfill the request - - AIBTC DAO Context Information: - You are an AI governance agent integrated with an AIBTC DAO. Your role is to interact with the DAO's smart contracts - on behalf of token holders, either by assisting human users or by acting autonomously within the DAO's rules. The DAO - is governed entirely by its token holders through proposals – members submit proposals, vote on them, and if a proposal passes, - it is executed on-chain. Always maintain the integrity of the DAO's decentralized process: never bypass on-chain governance, - and ensure all actions strictly follow the DAO's smart contract rules and parameters. - - Your responsibilities include: - 1. Helping users create and submit proposals to the DAO - 2. Guiding users through the voting process - 3. Explaining how DAO contract interactions work - 4. Preventing invalid actions and detecting potential exploits - 5. 
In autonomous mode, monitoring DAO state, proposing actions, and voting according to governance rules - - When interacting with users about the DAO, always: - - Retrieve contract addresses automatically instead of asking users - - Validate transactions before submission - - Present clear summaries of proposed actions - - Verify eligibility and check voting power - - Format transactions precisely according to blockchain requirements - - Provide confirmation and feedback after actions - - DAO Tools Usage: - For ANY DAO-related request, use the appropriate DAO tools to access real-time information: - - Use dao_list tool to retrieve all DAOs, their tokens, and extensions - - Use dao_search tool to find specific DAOs by name, description, token name, symbol, or contract ID - - Do NOT hardcode DAO information or assumptions about contract addresses - - Always query for the latest DAO data through the tools rather than relying on static information - - When analyzing user requests, determine if they're asking about a specific DAO or need a list of DAOs - - After retrieving DAO information, use it to accurately guide users through governance processes - - Examples of effective DAO tool usage: - 1. If user asks about voting on a proposal: First use dao_search to find the specific DAO, then guide them with the correct contract details - 2. If user asks to list available DAOs: Use dao_list to retrieve current DAOs and present them clearly - 3. If user wants to create a proposal: Use dao_search to get the DAO details first, then assist with the proposal creation using the current contract addresses - - Be decisive and action-oriented. Don't include phrases like "I would," "I could," or "I might." - Instead, use phrases like "I will," "I am going to," and "I'll execute." - Don't ask for confirmation before taking actions - assume the user wants you to proceed. 
- - User Query: {query} - """ - - # Add available tools to the planning prompt if available - if hasattr(self, "tool_names") and self.tool_names: - tool_info = "\n\nTools available to you:\n" - for tool_name in self.tool_names: - tool_info += f"- {tool_name}\n" - planning_prompt += tool_info - - # Add tool descriptions if available - if hasattr(self, "tool_descriptions"): - planning_prompt += self.tool_descriptions - - # Create planning messages, including persona if available - planning_messages = [] - - # If we're in the service context and persona is available, add it as a system message - if hasattr(self, "persona") and self.persona: - planning_messages.append(SystemMessage(content=self.persona)) - - # Add the planning prompt - planning_messages.append(HumanMessage(content=planning_prompt)) - - try: - logger.info("Creating thought process notes for user query") - - # Configure custom callback for planning to properly mark planning tokens - original_new_token = self.callback_handler.custom_on_llm_new_token - - # Create temporary wrapper to mark planning tokens - async def planning_token_wrapper(token, **kwargs): - # Add planning flag to tokens during the planning phase - if asyncio.iscoroutinefunction(original_new_token): - await original_new_token(token, planning_only=True, **kwargs) - else: - # If it's not a coroutine, assume it's a function that uses run_coroutine_threadsafe - loop = asyncio.get_running_loop() - asyncio.run_coroutine_threadsafe( - self.callback_handler.queue.put( - { - "type": "token", - "content": token, - "status": "planning", - "planning_only": True, - } - ), - loop, - ) - - # Set the temporary wrapper - self.callback_handler.custom_on_llm_new_token = planning_token_wrapper - - # Create a task to invoke the planning LLM - task = asyncio.create_task(self.planning_llm.ainvoke(planning_messages)) - - # Wait for the task to complete - response = await task - plan = response.content - - # Restore original callback - self.callback_handler.custom_on_llm_new_token = original_new_token - - logger.info("Thought process notes created successfully") - logger.debug(f"Notes content length: {len(plan)}") - - # Use the new process_step method to emit the plan with a planning status - await self.callback_handler.process_step( - content=plan, role="assistant", thought="Planning Phase" - ) - - return plan - except Exception as e: - # Restore original callback in case of error - if hasattr(self, "callback_handler") and hasattr( - self.callback_handler, "custom_on_llm_new_token" - ): - self.callback_handler.custom_on_llm_new_token = original_new_token - - logger.error(f"Failed to create plan: {str(e)}", exc_info=True) - # Let the LLM handle the planning naturally without a static fallback - raise - - def _create_graph(self) -> StateGraph: - """Create the PrePlan ReAct workflow graph.""" - logger.info("Creating PrePlan ReAct workflow graph") - tool_node = ToolNode(self.tools) - logger.debug(f"Created tool node with {len(self.tools)} tools") - - def should_continue(state: PreplanState) -> str: - messages = state["messages"] - last_message = messages[-1] - result = "tools" if last_message.tool_calls else END - logger.debug(f"Continue decision: {result}") - return result - - def call_model(state: PreplanState) -> Dict: - logger.debug("Calling model with current state") - messages = state["messages"] - - # Add the plan as a system message if it exists and hasn't been added yet - if state.get("plan") is not None and not any( - isinstance(msg, SystemMessage) and "thought" in 
msg.content.lower() - for msg in messages - ): - logger.info("Adding thought notes to messages as system message") - plan_message = SystemMessage( - content=f""" - Follow these decisive actions to address the user's query: - - {state["plan"]} - - Execute these steps directly without asking for confirmation. - Be decisive and action-oriented in your responses. - """ - ) - messages = [plan_message] + messages - else: - logger.debug("No thought notes to add or notes already added") - - # If decisive behavior is enabled and there's no plan-related system message, - # add a decisive behavior system message - if getattr(self, "decisive_behavior", False) and not any( - isinstance(msg, SystemMessage) for msg in messages - ): - logger.info("Adding decisive behavior instruction as system message") - decisive_message = SystemMessage( - content="Be decisive and take action without asking for confirmation. " - "When the user requests something, proceed directly with executing it." - ) - messages = [decisive_message] + messages - - logger.debug(f"Invoking LLM with {len(messages)} messages") - response = self.llm.invoke(messages) - logger.debug("Received model response") - logger.debug( - f"Response content length: {len(response.content) if hasattr(response, 'content') else 0}" - ) - return {"messages": [response]} - - workflow = StateGraph(PreplanState) - logger.debug("Created StateGraph") - - workflow.add_node("agent", call_model) - workflow.add_node("tools", tool_node) - workflow.add_edge(START, "agent") - workflow.add_conditional_edges("agent", should_continue) - workflow.add_edge("tools", "agent") - logger.info("Graph setup complete") - - return workflow - - def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: - """Integrate planning capability with the graph. - - Args: - graph: The graph to integrate with - **kwargs: Additional arguments - """ - # Implementation would modify the graph to include planning step - # before the main execution flow - pass - - -class PreplanLangGraphService: - """Service for executing PrePlan LangGraph operations""" - - def __init__(self): - # Initialize message processor here - self.message_processor = MessageProcessor() - - def setup_callback_handler(self, queue, loop): - # Import here to avoid circular dependencies - from services.workflows.workflow_service import BaseWorkflowService - - # Use the static method instead of instantiating BaseWorkflowService - return BaseWorkflowService.create_callback_handler(queue, loop) - - async def stream_task_results(self, task, queue): - # Import here to avoid circular dependencies - from services.workflows.workflow_service import BaseWorkflowService - - # Use the static method instead of instantiating BaseWorkflowService - async for chunk in BaseWorkflowService.stream_results_from_task( - task=task, callback_queue=queue, logger_name=self.__class__.__name__ - ): - yield chunk - - async def _execute_stream_impl( - self, - messages: List[Union[SystemMessage, HumanMessage, AIMessage]], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - **kwargs, - ) -> AsyncGenerator[Dict, None]: - """Execute a PrePlan React stream implementation. 
- - Args: - messages: Processed messages - input_str: Current user input - persona: Optional persona to use - tools_map: Optional tools to use - **kwargs: Additional arguments - - Returns: - Async generator of result chunks - """ - try: - # Import here to avoid circular dependencies - from services.workflows.workflow_service import WorkflowBuilder - - # Setup queue and callbacks - callback_queue = asyncio.Queue() - loop = asyncio.get_running_loop() - - # Setup callback handler - callback_handler = self.setup_callback_handler(callback_queue, loop) - - # Create workflow using builder pattern - workflow_builder = ( - WorkflowBuilder(PreplanReactWorkflow) - .with_callback_handler(callback_handler) - .with_tools(list(tools_map.values()) if tools_map else []) - ) - - workflow = workflow_builder.build() - - # Store persona and tool information for planning - if persona: - # Append decisiveness guidance to the persona - decisive_guidance = "\n\nBe decisive and take action without asking for confirmation. When the user requests something, proceed directly with executing it rather than asking if they want you to do it." - workflow.persona = persona + decisive_guidance - - # Store available tool names for planning - if tools_map: - workflow.tool_names = list(tools_map.keys()) - # Add tool descriptions to planning prompt - tool_descriptions = "\n\nTOOL DESCRIPTIONS:\n" - for name, tool in tools_map.items(): - description = getattr( - tool, "description", "No description available" - ) - tool_descriptions += f"- {name}: {description}\n" - workflow.tool_descriptions = tool_descriptions - - try: - # The thought notes will be streamed through callbacks - plan = await workflow.create_plan(input_str) - - except Exception as e: - logger.error(f"Planning failed, continuing with execution: {str(e)}") - yield { - "type": "token", - "content": "Proceeding directly to answer...\n\n", - } - # No plan will be provided, letting the LLM handle the task naturally - plan = None - - # Create graph and compile - graph = workflow._create_graph() - runnable = graph.compile() - logger.info("Graph compiled successfully") - - # Add the plan to the initial state - initial_state = {"messages": messages} - if plan is not None: - initial_state["plan"] = plan - logger.info("Added plan to initial state") - else: - logger.warning("No plan available for initial state") - - # Set up configuration with callbacks - config = {"callbacks": [callback_handler]} - logger.debug("Configuration set up with callbacks") - - # Execute workflow with callbacks config - logger.info("Creating task to execute workflow") - task = asyncio.create_task(runnable.ainvoke(initial_state, config=config)) - - # Stream results - async for chunk in self.stream_task_results(task, callback_queue): - yield chunk - - except Exception as e: - logger.error( - f"Failed to execute PrePlan ReAct stream: {str(e)}", exc_info=True - ) - raise ExecutionError(f"PrePlan ReAct stream execution failed: {str(e)}") - - # Add execute_stream method to maintain the same interface as BaseWorkflowService - async def execute_stream( - self, - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - **kwargs, - ) -> AsyncGenerator[Dict, None]: - """Execute a workflow stream. - - This processes the history and delegates to _execute_stream_impl. 
- """ - # Process messages - filtered_content = self.message_processor.extract_filtered_content(history) - messages = self.message_processor.convert_to_langchain_messages( - filtered_content, input_str, persona - ) - - # Call the implementation - async for chunk in self._execute_stream_impl( - messages=messages, - input_str=input_str, - persona=persona, - tools_map=tools_map, - **kwargs, - ): - yield chunk - - # Keep the old method for backward compatibility - async def execute_preplan_react_stream( - self, - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - ) -> AsyncGenerator[Dict, None]: - """Execute a PrePlan ReAct stream using LangGraph.""" - # Call the new method - async for chunk in self.execute_stream(history, input_str, persona, tools_map): - yield chunk - - -# Facade function for compatibility with the API -async def execute_preplan_react_stream( - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, -) -> AsyncGenerator[Dict, None]: - """Execute a PrePlan ReAct stream using LangGraph with optional persona.""" - service = PreplanLangGraphService() - async for chunk in service.execute_stream(history, input_str, persona, tools_map): - yield chunk diff --git a/services/workflows/proposal_evaluation.py b/services/workflows/proposal_evaluation.py index 5e9f8419..30563b58 100644 --- a/services/workflows/proposal_evaluation.py +++ b/services/workflows/proposal_evaluation.py @@ -1,11 +1,11 @@ """Proposal evaluation workflow.""" -import binascii -from typing import Dict, List, Optional, TypedDict +import asyncio +from typing import Any, Dict, List, Optional, TypedDict -from langchain.callbacks.base import BaseCallbackHandler from langchain.prompts import PromptTemplate -from langchain_core.documents import Document +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_openai import ChatOpenAI from langgraph.graph import END, Graph, StateGraph from pydantic import BaseModel, Field @@ -14,7 +14,6 @@ UUID, ExtensionFilter, Profile, - Prompt, PromptFilter, ProposalType, QueueMessageFilter, @@ -24,10 +23,12 @@ from lib.logger import configure_logger from services.workflows.base import ( BaseWorkflow, - VectorRetrievalCapability, - WebSearchCapability, ) -from services.workflows.vector_react import VectorLangGraphService, VectorReactState +from services.workflows.chat import ChatService, StreamingCallbackHandler +from services.workflows.planning_mixin import PlanningCapability +from services.workflows.utils import calculate_token_cost, decode_hex_parameters +from services.workflows.vector_mixin import VectorRetrievalCapability +from services.workflows.web_search_mixin import WebSearchCapability from tools.dao_ext_action_proposals import VoteOnActionProposalTool from tools.tools_factory import filter_tools_by_names, initialize_tools @@ -38,10 +39,10 @@ class ProposalEvaluationOutput(BaseModel): """Output model for proposal evaluation.""" approve: bool = Field( - description="Whether to approve (true) or reject (false) the proposal" + description="Decision: true to approve (vote FOR), false to reject (vote AGAINST)" ) confidence_score: float = Field( - description="The confidence score for the evaluation (0.0-1.0)" + description="Confidence score for the decision (0.0-1.0)" ) reasoning: str = Field(description="The reasoning behind the evaluation decision") @@ -65,14 +66,25 @@ class EvaluationState(TypedDict): agent_prompts: List[Dict] vector_results: 
Optional[List[Dict]] recent_tweets: Optional[List[Dict]] - web_search_results: Optional[List[Dict]] # Add field for web search results + web_search_results: Optional[List[Dict]] treasury_balance: Optional[float] - token_usage: Optional[Dict] # Add field for token usage tracking - model_info: Optional[Dict] # Add field for model information + contract_source: Optional[str] + plan: Optional[str] + # Token usage tracking per step + planning_token_usage: Optional[Dict] + web_search_token_usage: Optional[Dict] + evaluation_token_usage: Optional[Dict] + # Model info for cost calculation + evaluation_model_info: Optional[Dict] + planning_model_info: Optional[Dict] + web_search_model_info: Optional[Dict] class ProposalEvaluationWorkflow( - BaseWorkflow[EvaluationState], VectorRetrievalCapability, WebSearchCapability + BaseWorkflow[EvaluationState], + VectorRetrievalCapability, + WebSearchCapability, + PlanningCapability, ): """Workflow for evaluating DAO proposals and voting automatically.""" @@ -91,7 +103,25 @@ def __init__( temperature: Optional temperature setting for the model **kwargs: Additional arguments passed to parent """ + # Initialize planning LLM + planning_llm = ChatOpenAI( + model="o4-mini", + stream_usage=True, + streaming=True, + ) + + # Create callback handler for planning with queue + callback_handler = StreamingCallbackHandler(queue=asyncio.Queue()) + + # Initialize all parent classes including PlanningCapability super().__init__(model_name=model_name, temperature=temperature, **kwargs) + PlanningCapability.__init__( + self, + callback_handler=callback_handler, + planning_llm=planning_llm, + persona="You are a DAO proposal evaluation planner, focused on creating structured evaluation plans.", + ) + self.collection_names = collection_names or [ "knowledge_collection", "dao_collection", @@ -200,46 +230,117 @@ def _create_graph(self) -> Graph: """Create the evaluation graph.""" prompt = self._create_prompt() - # Create evaluation node - async def evaluate_proposal(state: EvaluationState) -> EvaluationState: - """Evaluate the proposal and determine how to vote.""" + async def fetch_context(state: EvaluationState) -> EvaluationState: + """Fetch context including web search, vector results, tweets, and contract source.""" try: - # Get proposal data from state - proposal_data = state["proposal_data"] - dao_id = state.get("dao_info", {}).get("id") + # --- Fetch Core Data --- # + proposal_id = state["proposal_id"] + dao_id = state.get("dao_id") + agent_id = state.get("agent_id") + + # Get proposal data + proposal_data = backend.get_proposal(proposal_id) + if not proposal_data: + raise ValueError(f"Proposal {proposal_id} not found") + + # Decode parameters if they exist + decoded_parameters = decode_hex_parameters(proposal_data.parameters) + + # Convert proposal data to dictionary + proposal_dict = { + "proposal_id": proposal_data.proposal_id, + "parameters": decoded_parameters or proposal_data.parameters, + "action": proposal_data.action, + "caller": proposal_data.caller, + "contract_principal": proposal_data.contract_principal, + "creator": proposal_data.creator, + "created_at_block": proposal_data.created_at_block, + "end_block": proposal_data.end_block, + "start_block": proposal_data.start_block, + "liquid_tokens": proposal_data.liquid_tokens, + "type": proposal_data.type, + "proposal_contract": proposal_data.proposal_contract, + } + state["proposal_data"] = proposal_dict # Update state with full data - # Perform web search for relevant context - try: - # Create search query from 
proposal data - web_search_query = f"DAO proposal {proposal_data.get('type', 'unknown')} - {proposal_data.get('parameters', '')}" + # Get DAO info (if dao_id wasn't passed explicitly, use proposal's) + if not dao_id and proposal_data.dao_id: + dao_id = proposal_data.dao_id + state["dao_id"] = dao_id # Update state if derived - # Use web search capability - web_search_results = await self.search_web( - query=web_search_query, - search_context_size="medium", # Use medium context size for balanced results - ) + dao_info = None + if dao_id: + dao_info = backend.get_dao(dao_id) + if not dao_info: + raise ValueError(f"DAO Information not found for ID: {dao_id}") + state["dao_info"] = dao_info.model_dump() + + # Get agent prompts + agent_prompts_text = [] + if agent_id: + try: + prompts = backend.list_prompts( + PromptFilter( + agent_id=agent_id, + dao_id=dao_id, + is_active=True, + ) + ) + agent_prompts_text = [p.prompt_text for p in prompts] + except Exception as e: + self.logger.error( + f"Failed to get agent prompts: {str(e)}", exc_info=True + ) + state["agent_prompts"] = agent_prompts_text - # Update state with web search results - state["web_search_results"] = web_search_results - self.logger.debug( - f"Web search query: {web_search_query} | Results count: {len(web_search_results)}" - ) - self.logger.debug( - f"Retrieved {len(web_search_results)} web search results" + # Get treasury balance + treasury_balance = None + try: + treasury_extensions = backend.list_extensions( + ExtensionFilter(dao_id=dao_info.id, type="EXTENSIONS_TREASURY") ) + if treasury_extensions: + hiro_api = HiroApi() + treasury_balance = hiro_api.get_address_balance( + treasury_extensions[0].contract_principal + ) + else: + self.logger.warning( + f"No treasury extension for DAO {dao_info.id}" + ) except Exception as e: self.logger.error( - f"Failed to perform web search: {str(e)}", exc_info=True + f"Failed to get treasury balance: {str(e)}", exc_info=True ) - state["web_search_results"] = [] + state["treasury_balance"] = treasury_balance + # --- End Fetch Core Data --- # - # Fetch recent tweets from queue if dao_id exists + # Use mixin capabilities for web search and vector retrieval + web_search_query = f"DAO proposal {proposal_dict.get('type', 'unknown')} - {proposal_dict.get('parameters', '')}" + + # Fetch web search results and token usage + web_search_results, web_search_token_usage = await self.search_web( + query=web_search_query, + search_context_size="medium", + ) + state["web_search_results"] = web_search_results + state["web_search_token_usage"] = web_search_token_usage + # Store web search model info (assuming gpt-4.1 as used in mixin) + state["web_search_model_info"] = { + "name": "gpt-4.1", + "temperature": None, + } + + vector_search_query = f"Proposal type: {proposal_dict.get('type')} - {proposal_dict.get('parameters', '')}" + state["vector_results"] = await self.retrieve_from_vector_store( + query=vector_search_query, limit=5 + ) + + # Fetch recent tweets recent_tweets = [] if dao_id: try: - # Add debug logging for dao_id self.logger.debug(f"Fetching tweets for DAO ID: {dao_id}") - queue_messages = backend.list_queue_messages( QueueMessageFilter( type=QueueMessageType.TWEET, @@ -247,62 +348,38 @@ async def evaluate_proposal(state: EvaluationState) -> EvaluationState: is_processed=True, ) ) - # Log the number of messages found - self.logger.debug(f"Found {len(queue_messages)} queue messages") - - # Sort by created_at and take last 5 sorted_messages = sorted( queue_messages, key=lambda x: x.created_at, 
reverse=True )[:5] - self.logger.debug(f"After sorting, have {len(sorted_messages)} messages") - recent_tweets = [ { "created_at": msg.created_at, - "message": msg.message.get('message', 'No text available') if isinstance(msg.message, dict) else msg.message, + "message": ( + msg.message.get("message", "No text available") + if isinstance(msg.message, dict) + else msg.message + ), "tweet_id": msg.tweet_id, } for msg in sorted_messages ] - self.logger.debug(f"Retrieved tweets: {recent_tweets}") - self.logger.debug( - f"Found {len(recent_tweets)} recent tweets for DAO {dao_id}" - ) except Exception as e: self.logger.error( - f"Failed to fetch recent tweets: {str(e)}", exc_info=True + f"Failed to fetch tweets: {str(e)}", exc_info=True ) - recent_tweets = [] - - # Update state with recent tweets state["recent_tweets"] = recent_tweets - # If this is a core proposal, fetch the contract source + # Fetch contract source for core proposals contract_source = "" - if proposal_data.get("type") == "core" and proposal_data.get( + if proposal_dict.get("type") == ProposalType.CORE and proposal_dict.get( "proposal_contract" ): - # Split contract address into parts - parts = proposal_data["proposal_contract"].split(".") + parts = proposal_dict["proposal_contract"].split(".") if len(parts) >= 2: - contract_address = parts[0] - contract_name = parts[1] - - # Use HiroApi to fetch contract source try: api = HiroApi() - result = api.get_contract_source( - contract_address, contract_name - ) - if "source" in result: - contract_source = result["source"] - self.logger.debug( - f"Retrieved contract source for {contract_address}.{contract_name}" - ) - else: - self.logger.warning( - f"Contract source not found in API response: {result}" - ) + result = api.get_contract_source(parts[0], parts[1]) + contract_source = result.get("source", "") except Exception as e: self.logger.error( f"Failed to fetch contract source: {str(e)}", @@ -310,167 +387,153 @@ async def evaluate_proposal(state: EvaluationState) -> EvaluationState: ) else: self.logger.warning( - f"Invalid contract address format: {proposal_data['proposal_contract']}" + f"Invalid contract format: {proposal_dict['proposal_contract']}" ) + state["contract_source"] = contract_source - # Retrieve relevant context from vector store - try: - # Create search query from proposal data - search_query = f"Proposal type: {proposal_data.get('type')} - {proposal_data.get('parameters', '')}" - - # Use vector retrieval capability - vector_results = await self.retrieve_from_vector_store( - query=search_query, limit=5 # Get top 5 most relevant documents - ) - - # Update state with vector results - state["vector_results"] = vector_results - self.logger.debug( - f"Searching vector store with query: {search_query} | Collection count: {len(self.collection_names)}" - ) - self.logger.debug(f"Vector search results: {vector_results}") - self.logger.debug( - f"Retrieved {len(vector_results)} relevant documents from vector store" - ) - - # Format vector context for prompt - vector_context = "\n\n".join( - [ - f"Related Context {i+1}:\n{doc.page_content}" - for i, doc in enumerate(vector_results) - ] - ) - except Exception as e: - self.logger.error( - f"Failed to retrieve from vector store: {str(e)}", exc_info=True - ) - vector_context = ( - "No additional context available from vector store." 
- ) + # Validate proposal data structure (moved from entry point) + proposal_type = proposal_dict.get("type") + if proposal_type == ProposalType.ACTION and not proposal_dict.get( + "parameters" + ): + raise ValueError("Action proposal missing parameters") + if proposal_type == ProposalType.CORE and not proposal_dict.get( + "proposal_contract" + ): + raise ValueError("Core proposal missing proposal_contract") - # Format prompt with state - self.logger.debug("Preparing evaluation prompt...") + return state + except Exception as e: + self.logger.error(f"Error in fetch_context: {str(e)}", exc_info=True) + state["reasoning"] = f"Error fetching context: {str(e)}" + # Propagate error state + return state - # Format agent prompts as a string + async def format_evaluation_prompt(state: EvaluationState) -> EvaluationState: + """Format the evaluation prompt using the fetched context.""" + if "reasoning" in state and "Error" in state["reasoning"]: + return state # Skip if context fetching failed + try: + # Extract data from state for easier access + proposal_data = state["proposal_data"] + dao_info = state.get("dao_info", {}) + treasury_balance = state.get("treasury_balance") + contract_source = state.get("contract_source", "") + agent_prompts = state.get("agent_prompts", []) + vector_results = state.get("vector_results", []) + recent_tweets = state.get("recent_tweets", []) + web_search_results = state.get("web_search_results", []) + + # Format agent prompts agent_prompts_str = "No agent-specific instructions available." - if state.get("agent_prompts"): - self.logger.debug(f"Raw agent prompts: {state['agent_prompts']}") - if ( - isinstance(state["agent_prompts"], list) - and state["agent_prompts"] - ): - # Just use the prompt text directly since that's what we're storing - agent_prompts_str = "\n\n".join(state["agent_prompts"]) - self.logger.debug( - f"Formatted agent prompts: {agent_prompts_str}" - ) + if agent_prompts: + if isinstance(agent_prompts, list): + agent_prompts_str = "\n\n".join(agent_prompts) else: self.logger.warning( - f"Invalid agent prompts format: {type(state['agent_prompts'])}" + f"Invalid agent prompts: {type(agent_prompts)}" ) - else: - self.logger.debug("No agent prompts found in state") - # Format web search results for prompt + # Format web search results web_search_content = "No relevant web search results found." - if state.get("web_search_results"): + if web_search_results: web_search_content = "\n\n".join( [ - f"Web Result {i+1}:\n{result['page_content']}\nSource: {result['metadata']['source_urls'][0]['url'] if result['metadata']['source_urls'] else 'Unknown'}" - for i, result in enumerate(state["web_search_results"]) + f"Web Result {i+1}:\n{res.get('page_content', '')}\nSource: {res.get('metadata', {}).get('source_urls', [{}])[0].get('url', 'Unknown')}" + for i, res in enumerate(web_search_results) ] ) - # Update formatted prompt with web search results - formatted_prompt = self._create_prompt().format( + # Format vector context + vector_context = "No additional context available from vector store." + if vector_results: + vector_context = "\n\n".join( + [ + f"Related Context {i+1}:\n{doc.page_content}" + for i, doc in enumerate(vector_results) + ] + ) + + # Format recent tweets + tweets_content = "\n".join( + [ + f"Tweet {i+1} ({tweet['created_at']}): {tweet['message']}" + for i, tweet in enumerate(recent_tweets) + ] + ) + + formatted_prompt = prompt.format( proposal_data=proposal_data, - dao_info=state.get( - "dao_info", "No additional DAO information available." 
-                    ),
-                    treasury_balance=state.get("treasury_balance"),
+                    dao_info=dao_info,
+                    treasury_balance=treasury_balance,
                     contract_source=contract_source,
                     agent_prompts=agent_prompts_str,
                     vector_context=vector_context,
-                    recent_tweets=(
-                        "\n".join(
-                            [
-                                f"Tweet {i+1} ({tweet['created_at']}): {tweet['message']}"
-                                for i, tweet in enumerate(recent_tweets)
-                            ]
-                        )
-                        if recent_tweets
-                        else "No recent tweets available."
-                    ),
+                    recent_tweets=tweets_content,
                     web_search_results=web_search_content,
                 )
+                state["formatted_prompt"] = formatted_prompt
+                return state
+            except Exception as e:
+                self.logger.error(f"Error formatting prompt: {str(e)}", exc_info=True)
+                state["reasoning"] = f"Error formatting prompt: {str(e)}"
+                return state

-                # Get evaluation from LLM
-                self.logger.debug("Starting LLM evaluation...")
+        async def call_evaluation_llm(state: EvaluationState) -> EvaluationState:
+            """Call the LLM with the formatted prompt for evaluation."""
+            if "reasoning" in state and "Error" in state["reasoning"]:
+                return state  # Skip if previous steps failed
+            try:
                 structured_output = self.llm.with_structured_output(
-                    ProposalEvaluationOutput,
-                    include_raw=True,  # Include raw response to get token usage
+                    ProposalEvaluationOutput, include_raw=True
                 )

-                # Invoke LLM with formatted prompt
-                result = structured_output.invoke(formatted_prompt)
-
-                # Extract the parsed result and token usage from raw response
-                self.logger.debug(
-                    f"Raw LLM result structure: {type(result).__name__} | Has parsed: {'parsed' in result if isinstance(result, dict) else False}"
+                # include_raw=True returns a dict with "parsed" and "raw" keys; usage metadata comes from "raw"
+                result: Dict[str, Any] = await structured_output.ainvoke(
+                    state["formatted_prompt"]
                 )
-                parsed_result = result["parsed"] if isinstance(result, dict) else result
-                model_info = {"name": self.model_name, "temperature": self.temperature}
-                if isinstance(result, dict) and "raw" in result:
-                    raw_msg = result["raw"]
-                    # Extract token usage
-                    if hasattr(raw_msg, "usage_metadata"):
-                        token_usage = raw_msg.usage_metadata
-                        self.logger.debug(
-                            f"Token usage details: input={token_usage.get('input_tokens', 0)} | output={token_usage.get('output_tokens', 0)} | total={token_usage.get('total_tokens', 0)}"
+                parsed_result = result.get("parsed")
+                if not isinstance(parsed_result, ProposalEvaluationOutput):
+                    # Attempt to handle cases where parsing might return the raw dict
+                    if isinstance(parsed_result, dict):
+                        parsed_result = ProposalEvaluationOutput(**parsed_result)
+                    else:
+                        raise TypeError(
+                            f"Expected ProposalEvaluationOutput or dict, got {type(parsed_result)}"
                         )
+
+                model_info = {"name": self.model_name, "temperature": self.temperature}
+                token_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
+
+                raw_response = result.get("raw")
+                if raw_response:
+                    if hasattr(raw_response, "usage_metadata"):
+                        token_usage = raw_response.usage_metadata
                     else:
-                        self.logger.warning("No usage_metadata found in raw response")
-                        token_usage = {
-                            "input_tokens": 0,
-                            "output_tokens": 0,
-                            "total_tokens": 0,
-                        }
+                        self.logger.warning("Raw response missing usage_metadata")
                 else:
-                    self.logger.warning("No raw response available")
-                    token_usage = {
-                        "input_tokens": 0,
-                        "output_tokens": 0,
-                        "total_tokens": 0,
-                    }
+                    self.logger.warning("LLM result missing raw response data")

-                self.logger.debug(f"Parsed evaluation result: {parsed_result}")
-
-                # Update state
-                state["formatted_prompt"] = formatted_prompt
                 state["approve"] = parsed_result.approve
                 state["confidence_score"] =
parsed_result.confidence_score state["reasoning"] = parsed_result.reasoning - state["token_usage"] = token_usage - state["model_info"] = model_info - - # Calculate token costs - token_costs = calculate_token_cost(token_usage, model_info["name"]) + state["evaluation_token_usage"] = token_usage + state["evaluation_model_info"] = model_info - # Log final evaluation summary self.logger.debug( - f"Evaluation complete: Decision={'APPROVE' if parsed_result.approve else 'REJECT'} | Confidence={parsed_result.confidence_score:.2f} | Model={model_info['name']} (temp={model_info['temperature']}) | Tokens={token_usage} | Cost=${token_costs['total_cost']:.4f}" + f"Evaluation step complete: Decision={'APPROVE' if parsed_result.approve else 'REJECT'} | Confidence={parsed_result.confidence_score:.2f}" ) self.logger.debug(f"Full reasoning: {parsed_result.reasoning}") return state except Exception as e: - self.logger.error( - f"Error in evaluate_proposal: {str(e)}", exc_info=True - ) + self.logger.error(f"Error calling LLM: {str(e)}", exc_info=True) state["approve"] = False state["confidence_score"] = 0.0 - state["reasoning"] = f"Error during evaluation: {str(e)}" + state["reasoning"] = f"Error during LLM evaluation: {str(e)}" return state # Create decision node @@ -503,6 +566,17 @@ async def should_vote(state: EvaluationState) -> str: async def vote_on_proposal(state: EvaluationState) -> EvaluationState: """Vote on the proposal using VectorReact workflow.""" try: + # Check if wallet_id is available + if not state.get("wallet_id"): + self.logger.warning( + "No wallet_id provided for voting, skipping vote" + ) + state["vote_result"] = { + "success": False, + "error": "No wallet_id provided for voting", + } + return state + self.logger.debug( f"Setting up VectorReact workflow: proposal_id={state['proposal_id']} | vote={state['approve']}" ) @@ -515,7 +589,7 @@ async def vote_on_proposal(state: EvaluationState) -> EvaluationState: vote_instruction = f"I need you to vote on a DAO proposal with ID {state['proposal_id']} in the contract {state['action_proposals_contract']}. Please vote {'FOR' if state['approve'] else 'AGAINST'} the proposal. Use the dao_action_vote_on_proposal tool to submit the vote." # Create VectorLangGraph service with collections - service = VectorLangGraphService( + service = ChatService( collection_names=self.collection_names, ) @@ -591,16 +665,79 @@ async def skip_voting(state: EvaluationState) -> EvaluationState: } return state + # --- Planning Node --- # + async def plan_evaluation(state: EvaluationState) -> EvaluationState: + """Generate a plan for evaluating the proposal using the PlanningCapability mixin.""" + try: + self.logger.debug( + "Generating evaluation plan using PlanningCapability..." + ) + + # Construct initial context for planning + initial_context = ( + f"Proposal ID: {state['proposal_id']}\n" + f"DAO ID: {state.get('dao_id')}\n" + f"Agent ID: {state.get('agent_id')}\n" + f"Auto-Vote Enabled: {state.get('auto_vote')}" + ) + + # Create planning query + planning_query = ( + f"Create a detailed plan for evaluating the following DAO proposal:\n\n" + f"{initial_context}\n\n" + f"The plan should cover:\n" + f"1. Data gathering (proposal details, DAO context, treasury info)\n" + f"2. Analysis approach (including use of vector search and web search)\n" + f"3. Evaluation criteria and decision making process\n" + f"4. 
Voting execution strategy (if auto-vote is enabled)" + ) + + # Use the mixin's create_plan method + plan, planning_token_usage = await self.create_plan( + query=planning_query, context_docs=state.get("vector_results", []) + ) + + state["plan"] = plan + state["planning_token_usage"] = planning_token_usage + # Store planning model info + state["planning_model_info"] = { + "name": self.planning_llm.model_name, + "temperature": self.planning_llm.temperature, + } + + self.logger.info("Evaluation plan generated using PlanningCapability.") + self.logger.debug(f"Generated Plan:\n{plan}") + return state + + except Exception as e: + self.logger.error(f"Error generating plan: {str(e)}", exc_info=True) + state["plan"] = f"Error generating plan: {str(e)}" + state["planning_token_usage"] = { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + state["planning_model_info"] = {"name": "unknown", "temperature": None} + return state + # Create the graph workflow = StateGraph(EvaluationState) # Add nodes - workflow.add_node("evaluate", evaluate_proposal) + workflow.add_node("plan_evaluation", plan_evaluation) # New planning node + workflow.add_node("fetch_context", fetch_context) + workflow.add_node("format_prompt", format_evaluation_prompt) + workflow.add_node( + "evaluate", call_evaluation_llm + ) # Renamed from evaluate_proposal workflow.add_node("vote", vote_on_proposal) workflow.add_node("skip_vote", skip_voting) # Set up the conditional branching - workflow.set_entry_point("evaluate") + workflow.set_entry_point("plan_evaluation") # Start with planning + workflow.add_edge("plan_evaluation", "fetch_context") # Plan -> Fetch + workflow.add_edge("fetch_context", "format_prompt") + workflow.add_edge("format_prompt", "evaluate") workflow.add_conditional_edges( "evaluate", should_vote, @@ -616,11 +753,13 @@ async def skip_voting(state: EvaluationState) -> EvaluationState: def _validate_state(self, state: EvaluationState) -> bool: """Validate the workflow state.""" - required_fields = ["proposal_id", "proposal_data"] + # Only validate minimal required fields for initial state + # Other fields like proposal_data are fetched within the workflow + required_fields = ["proposal_id"] # Log the state for debugging self.logger.debug( - f"Validating state: proposal_id={state.get('proposal_id')} | proposal_type={state.get('proposal_data', {}).get('type', 'unknown')}" + f"Validating initial state: proposal_id={state.get('proposal_id')}" ) # Check all fields and log problems @@ -632,30 +771,8 @@ def _validate_state(self, state: EvaluationState) -> bool: self.logger.error(f"Empty required field: {field}") return False - # Get proposal type - proposal_type = state["proposal_data"].get("type", ProposalType.ACTION) - - # Validate based on proposal type - if proposal_type == ProposalType.ACTION: - # Action proposals require action_proposals_contract and parameters - if not state.get("action_proposals_contract"): - self.logger.error( - "Missing action_proposals_contract for action proposal" - ) - return False - if not state["proposal_data"].get("parameters"): - self.logger.error("No parameters field in action proposal data") - return False - elif proposal_type == ProposalType.CORE: - # Core proposals require proposal_contract - if not state["proposal_data"].get("proposal_contract"): - self.logger.error("Missing proposal_contract for core proposal") - return False - else: - self.logger.error(f"Invalid proposal type: {proposal_type}") - return False - - self.logger.debug("State validation successful") + # Note: 
Detailed validation of proposal_data happens in fetch_context node + self.logger.debug("Initial state validation successful") return True @@ -691,111 +808,10 @@ def get_proposal_evaluation_tools( return filtered_tools -def decode_hex_parameters(hex_string: Optional[str]) -> Optional[str]: - """Decodes a hexadecimal-encoded string if valid.""" - if not hex_string: - return None - if hex_string.startswith("0x"): - hex_string = hex_string[2:] # Remove "0x" prefix - try: - decoded_bytes = binascii.unhexlify(hex_string) - decoded_string = decoded_bytes.decode( - "utf-8", errors="ignore" - ) # Decode as UTF-8 - logger.debug(f"Successfully decoded hex string: {hex_string[:20]}...") - return decoded_string - except (binascii.Error, UnicodeDecodeError) as e: - logger.warning(f"Failed to decode hex string: {str(e)}") - return None # Return None if decoding fails - - -def calculate_token_cost( - token_usage: Dict[str, int], model_name: str -) -> Dict[str, float]: - """Calculate the cost of token usage based on current pricing. - - Args: - token_usage: Dictionary containing input_tokens and output_tokens - model_name: Name of the model used - - Returns: - Dictionary containing cost breakdown and total cost - """ - # Current pricing per million tokens (as of August 2024) - MODEL_PRICES = { - "gpt-4o": { - "input": 2.50, # $2.50 per million input tokens - "output": 10.00, # $10.00 per million output tokens - }, - "gpt-4.1": { - "input": 2.00, # $2.00 per million input tokens - "output": 8.00, # $8.00 per million output tokens - }, - "gpt-4.1-mini": { - "input": 0.40, # $0.40 per million input tokens - "output": 1.60, # $1.60 per million output tokens - }, - "gpt-4.1-nano": { - "input": 0.10, # $0.10 per million input tokens - "output": 0.40, # $0.40 per million output tokens - }, - # Default to gpt-4.1 pricing if model not found - "default": { - "input": 2.00, - "output": 8.00, - }, - } - - # Get pricing for the model, default to gpt-4.1 pricing if not found - model_prices = MODEL_PRICES.get(model_name.lower(), MODEL_PRICES["default"]) - - # Extract token counts, ensuring we get integers and handle None values - try: - input_tokens = int(token_usage.get("input_tokens", 0)) - output_tokens = int(token_usage.get("output_tokens", 0)) - except (TypeError, ValueError) as e: - logger.error(f"Error converting token counts to integers: {str(e)}") - input_tokens = 0 - output_tokens = 0 - - # Calculate costs with more precision - input_cost = (input_tokens / 1_000_000.0) * model_prices["input"] - output_cost = (output_tokens / 1_000_000.0) * model_prices["output"] - total_cost = input_cost + output_cost - - # Create detailed token usage breakdown - token_details = { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": input_tokens + output_tokens, - "model_name": model_name, - "input_price_per_million": model_prices["input"], - "output_price_per_million": model_prices["output"], - } - - # Add token details if available - if "input_token_details" in token_usage: - token_details["input_token_details"] = token_usage["input_token_details"] - if "output_token_details" in token_usage: - token_details["output_token_details"] = token_usage["output_token_details"] - - # Debug logging with more detail - logger.debug( - f"Cost calculation details: Model={model_name} | Input={input_tokens} tokens * ${model_prices['input']}/1M = ${input_cost:.6f} | Output={output_tokens} tokens * ${model_prices['output']}/1M = ${output_cost:.6f} | Total=${total_cost:.6f} | Token details={token_details}" - ) - - 
return { - "input_cost": round(input_cost, 6), - "output_cost": round(output_cost, 6), - "total_cost": round(total_cost, 6), - "currency": "USD", - "details": token_details, - } - - async def evaluate_and_vote_on_proposal( proposal_id: UUID, wallet_id: Optional[UUID] = None, + agent_id: Optional[UUID] = None, auto_vote: bool = True, confidence_threshold: float = 0.7, dao_id: Optional[UUID] = None, @@ -805,6 +821,7 @@ async def evaluate_and_vote_on_proposal( Args: proposal_id: The ID of the proposal to evaluate and vote on wallet_id: Optional wallet ID to use for voting + agent_id: Optional agent ID to use for retrieving prompts auto_vote: Whether to automatically vote based on the evaluation confidence_threshold: Minimum confidence score required to auto-vote (0.0-1.0) dao_id: Optional DAO ID to explicitly pass to the workflow @@ -817,183 +834,77 @@ async def evaluate_and_vote_on_proposal( ) try: - # Get proposal data directly from the database - proposal_data = backend.get_proposal(proposal_id) - if not proposal_data: - error_msg = f"Proposal {proposal_id} not found in database" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # Decode parameters if they exist - decoded_parameters = decode_hex_parameters(proposal_data.parameters) - if decoded_parameters: - logger.debug( - f"Decoded proposal parameters: length={len(decoded_parameters) if decoded_parameters else 0}" - ) - - # Convert proposal data to dictionary and ensure parameters exist - proposal_dict = { - "proposal_id": proposal_data.proposal_id, - "parameters": decoded_parameters - or proposal_data.parameters, # Use decoded if available - "action": proposal_data.action, - "caller": proposal_data.caller, - "contract_principal": proposal_data.contract_principal, - "creator": proposal_data.creator, - "created_at_block": proposal_data.created_at_block, - "end_block": proposal_data.end_block, - "start_block": proposal_data.start_block, - "liquid_tokens": proposal_data.liquid_tokens, - "type": proposal_data.type, # Add proposal type - "proposal_contract": proposal_data.proposal_contract, # Add proposal contract for core proposals - } - - # For action proposals, parameters are required - if proposal_data.type == ProposalType.ACTION and not proposal_dict.get( - "parameters" - ): - error_msg = "No parameters found in action proposal data" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # For core proposals, proposal_contract is required - if proposal_data.type == ProposalType.CORE and not proposal_dict.get( - "proposal_contract" - ): - error_msg = "No proposal contract found in core proposal data" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # Get DAO info based on provided dao_id or from proposal - dao_info = None - if dao_id: - logger.debug( - f"Using provided DAO ID: {dao_id} | Found={dao_info is not None}" - ) - dao_info = backend.get_dao(dao_id) - if not dao_info: - logger.warning( - f"Provided DAO ID {dao_id} not found, falling back to proposal's DAO ID" - ) - - # If dao_info is still None, try to get it from proposal's dao_id - if not dao_info and proposal_data.dao_id: - logger.debug( - f"Using proposal's DAO ID: {proposal_data.dao_id} | Found={dao_info is not None}" - ) - dao_info = backend.get_dao(proposal_data.dao_id) - - if not dao_info: - error_msg = "Could not find DAO information" - logger.error(error_msg) - return {"success": False, "error": error_msg} - - # Get the treasury extension for the DAO - treasury_extension = None - try: - 
treasury_extensions = backend.list_extensions( - ExtensionFilter(dao_id=dao_info.id, type="EXTENSIONS_TREASURY") - ) - if treasury_extensions: - treasury_extension = treasury_extensions[0] - logger.debug( - f"Found treasury extension: contract_principal={treasury_extension.contract_principal}" - ) - - # Get treasury balance from Hiro API - hiro_api = HiroApi() - treasury_balance = hiro_api.get_address_balance( - treasury_extension.contract_principal - ) - logger.debug(f"Treasury balance retrieved: balance={treasury_balance}") - else: - logger.warning(f"No treasury extension found for DAO {dao_info.id}") - treasury_balance = None - except Exception as e: - logger.error(f"Failed to get treasury balance: {str(e)}", exc_info=True) - treasury_balance = None - - logger.debug( - f"Processing proposal for DAO: {dao_info.name} (ID: {dao_info.id})" - ) - - # Get the wallet and agent information if available - agent_id = None - if wallet_id: + # Determine effective agent ID + effective_agent_id = agent_id + if not effective_agent_id and wallet_id: wallet = backend.get_wallet(wallet_id) if wallet and wallet.agent_id: - agent_id = wallet.agent_id - logger.debug(f"Using agent ID {agent_id} for wallet {wallet_id}") + effective_agent_id = wallet.agent_id + logger.debug( + f"Using agent ID {effective_agent_id} from wallet {wallet_id}" + ) - # Get agent prompts - agent_prompts = [] + # Fetch the primary prompt to determine model and temperature settings + # Note: Actual prompt text fetching happens inside the workflow now. model_name = "gpt-4.1" # Default model temperature = 0.1 # Default temperature - try: - logger.debug( - f"Fetching prompts for agent_id={agent_id}, dao_id={proposal_data.dao_id}" - ) - prompts = backend.list_prompts( - PromptFilter( - agent_id=agent_id, - dao_id=proposal_data.dao_id, - is_active=True, - ) - ) - logger.debug(f"Retrieved prompts: {prompts}") - - # Store the full Prompt objects and get model settings from first prompt - agent_prompts = prompts - if agent_prompts: - first_prompt = agent_prompts[0] - model_name = first_prompt.model or model_name - temperature = ( - first_prompt.temperature - if first_prompt.temperature is not None - else temperature - ) - logger.debug( - f"Using model configuration: {model_name} (temperature={temperature})" + if effective_agent_id: + try: + # We only need one active prompt to get settings + prompts = backend.list_prompts( + PromptFilter( + agent_id=effective_agent_id, + dao_id=dao_id, # Assuming dao_id is available, might need refinement + is_active=True, + limit=1, + ) ) - else: - logger.warning( - f"No active prompts found for agent_id={agent_id}, dao_id={proposal_data.dao_id}" + if prompts: + first_prompt = prompts[0] + model_name = first_prompt.model or model_name + temperature = ( + first_prompt.temperature + if first_prompt.temperature is not None + else temperature + ) + logger.debug( + f"Using model settings from agent {effective_agent_id}: {model_name} (temp={temperature})" + ) + else: + logger.warning( + f"No active prompts found for agent {effective_agent_id} to determine settings." 
+ ) + except Exception as e: + logger.error( + f"Failed to get agent prompt settings: {str(e)}", exc_info=True ) - except Exception as e: - logger.error(f"Failed to get agent prompts: {str(e)}", exc_info=True) - # Initialize state + # Initialize state (minimal initial data) state = { - "action_proposals_contract": proposal_dict["contract_principal"], - "action_proposals_voting_extension": proposal_dict["action"], - "proposal_id": proposal_dict["proposal_id"], - "proposal_data": proposal_dict, - "dao_info": dao_info.model_dump() if dao_info else {}, - "treasury_balance": treasury_balance, - "agent_prompts": ( - [p.prompt_text for p in agent_prompts] if agent_prompts else [] - ), + "proposal_id": proposal_id, + "dao_id": dao_id, # Pass DAO ID to the workflow + "agent_id": effective_agent_id, # Pass Agent ID for prompt loading + "wallet_id": wallet_id, # Pass wallet ID for voting tool "approve": False, "confidence_score": 0.0, "reasoning": "", "vote_result": None, - "wallet_id": wallet_id, "confidence_threshold": confidence_threshold, "auto_vote": auto_vote, "vector_results": None, "recent_tweets": None, "web_search_results": None, "token_usage": None, - "model_info": { - "name": "unknown", - "temperature": None, - }, + "model_info": None, + "plan": None, + "planning_token_usage": None, + "web_search_token_usage": None, + "evaluation_token_usage": None, + "evaluation_model_info": None, + "planning_model_info": None, + "web_search_model_info": None, } - logger.debug( - f"Agent prompts count: {len(state['agent_prompts'] or [])} | Has prompts: {bool(state['agent_prompts'])}" - ) - # Create and run workflow with model settings from prompt workflow = ProposalEvaluationWorkflow( model_name=model_name, temperature=temperature @@ -1042,29 +953,86 @@ async def evaluate_and_vote_on_proposal( "recent_tweets": result["recent_tweets"], "web_search_results": result["web_search_results"], "treasury_balance": result.get("treasury_balance"), - "token_usage": result.get( - "token_usage", - {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + "planning_token_usage": result.get("planning_token_usage"), + "web_search_token_usage": result.get("web_search_token_usage"), + "evaluation_token_usage": result.get("evaluation_token_usage"), + "evaluation_model_info": result.get("evaluation_model_info"), + "planning_model_info": result.get("planning_model_info"), + "web_search_model_info": result.get("web_search_model_info"), + } + + # --- Aggregate Token Usage and Calculate Costs --- # + total_token_usage_by_model = {} + total_cost_by_model = {} + total_overall_cost = 0.0 + + steps = [ + ( + "planning", + result.get("planning_token_usage"), + result.get("planning_model_info"), ), - "model_info": result.get( - "model_info", {"name": "unknown", "temperature": None} + ( + "web_search", + result.get("web_search_token_usage"), + result.get("web_search_model_info"), ), - } + ( + "evaluation", + result.get("evaluation_token_usage"), + result.get("evaluation_model_info"), + ), + ] - # Calculate token costs - token_costs = calculate_token_cost( - final_result["token_usage"], final_result["model_info"]["name"] - ) - final_result["token_costs"] = token_costs + for step_name, usage, model_info in steps: + if usage and model_info and model_info.get("name") != "unknown": + model_name = model_info["name"] - # For the example token usage shown: - # Input: 7425 tokens * ($2.50/1M) = $0.0186 - # Output: 312 tokens * ($10.00/1M) = $0.0031 - # Total: $0.0217 + # Aggregate usage per model + if model_name not in 
total_token_usage_by_model: + total_token_usage_by_model[model_name] = { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + total_token_usage_by_model[model_name]["input_tokens"] += usage.get( + "input_tokens", 0 + ) + total_token_usage_by_model[model_name]["output_tokens"] += usage.get( + "output_tokens", 0 + ) + total_token_usage_by_model[model_name]["total_tokens"] += usage.get( + "total_tokens", 0 + ) + + # Calculate cost for this step/model + step_cost = calculate_token_cost(usage, model_name) + + # Aggregate cost per model + if model_name not in total_cost_by_model: + total_cost_by_model[model_name] = 0.0 + total_cost_by_model[model_name] += step_cost["total_cost"] + total_overall_cost += step_cost["total_cost"] + else: + logger.warning( + f"Skipping cost calculation for step '{step_name}' due to missing usage or model info." + ) + final_result["total_token_usage_by_model"] = total_token_usage_by_model + final_result["total_cost_by_model"] = total_cost_by_model + final_result["total_overall_cost"] = total_overall_cost + # --- End Aggregation --- # + + # Updated Logging logger.debug( - f"Proposal evaluation completed: Success={final_result['success']} | Decision={'APPROVE' if final_result['evaluation']['approve'] else 'REJECT'} | Confidence={final_result['evaluation']['confidence_score']:.2f} | Auto-voted={final_result['auto_voted']} | Transaction={tx_id or 'None'} | Model={final_result['model_info']['name']} | Token Usage={final_result['token_usage']} | Cost (USD)=${token_costs['total_cost']:.4f} (Input=${token_costs['input_cost']:.4f} for {token_costs['details']['input_tokens']} tokens, Output=${token_costs['output_cost']:.4f} for {token_costs['details']['output_tokens']} tokens)" + f"Proposal evaluation completed: Success={final_result['success']} | " + f"Decision={'APPROVE' if final_result['evaluation']['approve'] else 'REJECT'} | " + f"Confidence={final_result['evaluation']['confidence_score']:.2f} | " + f"Auto-voted={final_result['auto_voted']} | Transaction={tx_id or 'None'} | " + f"Total Cost (USD)=${total_overall_cost:.4f}" ) + logger.debug(f"Cost Breakdown: {total_cost_by_model}") + logger.debug(f"Token Usage Breakdown: {total_token_usage_by_model}") logger.debug(f"Full evaluation result: {final_result}") return final_result @@ -1080,21 +1048,34 @@ async def evaluate_and_vote_on_proposal( async def evaluate_proposal_only( proposal_id: UUID, wallet_id: Optional[UUID] = None, + agent_id: Optional[UUID] = None, + dao_id: Optional[UUID] = None, ) -> Dict: """Evaluate a proposal without voting. 
Args: proposal_id: The ID of the proposal to evaluate wallet_id: Optional wallet ID to use for retrieving proposal data + agent_id: Optional agent ID associated with the evaluation + dao_id: Optional DAO ID associated with the proposal Returns: Dictionary containing the evaluation results """ logger.debug(f"Starting proposal-only evaluation: proposal_id={proposal_id}") + # Determine effective agent ID (same logic as evaluate_and_vote) + effective_agent_id = agent_id + if not effective_agent_id and wallet_id: + wallet = backend.get_wallet(wallet_id) + if wallet and wallet.agent_id: + effective_agent_id = wallet.agent_id + result = await evaluate_and_vote_on_proposal( proposal_id=proposal_id, wallet_id=wallet_id, + agent_id=effective_agent_id, + dao_id=dao_id, auto_vote=False, ) diff --git a/services/workflows/react.py b/services/workflows/react.py deleted file mode 100644 index d5742f90..00000000 --- a/services/workflows/react.py +++ /dev/null @@ -1,590 +0,0 @@ -"""ReAct workflow functionality.""" - -import asyncio -import datetime -import uuid -from dataclasses import dataclass -from typing import ( - Annotated, - Any, - AsyncGenerator, - Dict, - List, - Optional, - TypedDict, - Union, -) - -from langchain.callbacks.base import BaseCallbackHandler -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from langchain_core.outputs import LLMResult -from langchain_openai import ChatOpenAI -from langgraph.graph import END, START, StateGraph -from langgraph.graph.message import add_messages -from langgraph.prebuilt import ToolNode - -from lib.logger import configure_logger -from services.workflows.base import BaseWorkflow, ExecutionError, StreamingError - -# Remove this import to avoid circular dependencies -# from services.workflows.workflow_service import BaseWorkflowService, WorkflowBuilder - -logger = configure_logger(__name__) - - -@dataclass -class MessageContent: - """Data class for message content""" - - role: str - content: str - tool_calls: Optional[List[Dict]] = None - - @classmethod - def from_dict(cls, data: Dict) -> "MessageContent": - """Create MessageContent from dictionary""" - return cls( - role=data.get("role", ""), - content=data.get("content", ""), - tool_calls=data.get("tool_calls"), - ) - - -class MessageProcessor: - """Processor for messages""" - - @staticmethod - def extract_filtered_content(history: List[Dict]) -> List[Dict]: - """Extract and filter content from message history.""" - logger.debug( - f"Starting content extraction from history with {len(history)} messages" - ) - filtered_content = [] - - for message in history: - logger.debug(f"Processing message type: {message.get('role')}") - if message.get("role") in ["user", "assistant"]: - filtered_content.append(MessageContent.from_dict(message).__dict__) - - logger.debug( - f"Finished filtering content, extracted {len(filtered_content)} messages" - ) - return filtered_content - - @staticmethod - def convert_to_langchain_messages( - filtered_content: List[Dict], - current_input: str, - persona: Optional[str] = None, - ) -> List[Union[SystemMessage, HumanMessage, AIMessage]]: - """Convert filtered content to LangChain message format.""" - messages = [] - - # Add decisiveness instruction - decisiveness_instruction = "Be decisive and action-oriented. When the user requests something, execute it immediately without asking for confirmation." 
- - if persona: - logger.debug("Adding persona message with decisiveness instruction") - # Add the decisiveness instruction to the persona - enhanced_persona = f"{persona}\n\n{decisiveness_instruction}" - messages.append(SystemMessage(content=enhanced_persona)) - else: - # If no persona, add the decisiveness instruction as a system message - logger.debug("Adding decisiveness instruction as system message") - messages.append(SystemMessage(content=decisiveness_instruction)) - - for msg in filtered_content: - if msg["role"] == "user": - messages.append(HumanMessage(content=msg["content"])) - else: - content = msg.get("content") or "" - if msg.get("tool_calls"): - messages.append( - AIMessage(content=content, tool_calls=msg["tool_calls"]) - ) - else: - messages.append(AIMessage(content=content)) - - messages.append(HumanMessage(content=current_input)) - logger.debug(f"Prepared message chain with {len(messages)} total messages") - return messages - - -class StreamingCallbackHandler(BaseCallbackHandler): - """Handle callbacks from LangChain and stream results to a queue.""" - - def __init__( - self, - queue: asyncio.Queue, - on_llm_new_token: Optional[callable] = None, - on_llm_end: Optional[callable] = None, - ): - """Initialize the callback handler with a queue.""" - self.queue = queue - self.tool_states = {} # Store tool states by invocation ID - self.tool_inputs = {} # Store tool inputs by invocation ID - self.active_tools = {} # Track active tools by name for fallback - self.custom_on_llm_new_token = on_llm_new_token - self.custom_on_llm_end = on_llm_end - # Track the current execution phase - self.current_phase = "processing" # Default phase is processing - - def _ensure_loop(self) -> asyncio.AbstractEventLoop: - """Get the current event loop or create a new one if necessary.""" - try: - loop = asyncio.get_running_loop() - return loop - except RuntimeError: - logger.debug("No running event loop found. Creating a new one.") - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - - async def _async_put_to_queue(self, item: Dict) -> None: - """Put an item in the queue asynchronously.""" - try: - await self.queue.put(item) - except Exception as e: - logger.error(f"Failed to put item in queue: {str(e)}") - raise StreamingError(f"Queue operation failed: {str(e)}") - - def _put_to_queue(self, item: Dict) -> None: - """Put an item in the queue, handling event loop considerations.""" - try: - loop = self._ensure_loop() - if loop.is_running(): - future = asyncio.run_coroutine_threadsafe( - self._async_put_to_queue(item), loop - ) - future.result() - else: - loop.run_until_complete(self._async_put_to_queue(item)) - except Exception as e: - logger.error(f"Failed to put item in queue: {str(e)}") - raise StreamingError(f"Queue operation failed: {str(e)}") - - def _get_tool_info( - self, invocation_id: Optional[str], tool_name: Optional[str] = None - ) -> Optional[tuple]: - """Get tool information using either invocation_id or tool_name. 
- - Returns: - Optional[tuple]: (tool_name, tool_input, invocation_id) if found, None otherwise - """ - if invocation_id and invocation_id in self.tool_states: - return ( - self.tool_states[invocation_id], - self.tool_inputs.get(invocation_id, ""), - invocation_id, - ) - elif tool_name and tool_name in self.active_tools: - active_info = self.active_tools[tool_name] - return (tool_name, active_info["input"], active_info["invocation_id"]) - return None - - async def process_step( - self, content: str, role: str = "assistant", thought: Optional[str] = None - ) -> None: - """Process a planning step and queue it with the planning status. - - Args: - content: The planning step content - role: The role associated with the step (usually assistant) - thought: Optional thought process notes - """ - try: - # Create step message with explicit planning status - current_time = datetime.datetime.now().isoformat() - step_message = { - "type": "step", - "status": "planning", # Explicitly mark as planning phase - "content": content, - "role": role, - "thought": thought - or "Planning Phase", # Default to Planning Phase if thought is not provided - "created_at": current_time, - "planning_only": True, # Mark this content as planning-only to prevent duplication - } - - logger.debug(f"Queuing planning step message with length: {len(content)}") - await self._async_put_to_queue(step_message) - except Exception as e: - logger.error(f"Failed to process planning step: {str(e)}") - raise StreamingError(f"Planning step processing failed: {str(e)}") - - def on_tool_start(self, serialized: Dict, input_str: str, **kwargs) -> None: - """Run when tool starts running.""" - tool_name = serialized.get("name") - if not tool_name: - logger.warning("Tool start called without tool name") - return - - invocation_id = kwargs.get("invocation_id", str(uuid.uuid4())) - - # Store in both tracking systems - self.tool_states[invocation_id] = tool_name - self.tool_inputs[invocation_id] = input_str - self.active_tools[tool_name] = { - "invocation_id": invocation_id, - "input": input_str, - "start_time": datetime.datetime.now(), - } - - logger.info( - f"Tool started: {tool_name} (ID: {invocation_id}) with input: {input_str[:100]}..." 
- ) - - def on_tool_end(self, output: str, **kwargs) -> None: - """Run when tool ends running.""" - invocation_id = kwargs.get("invocation_id") - tool_name = kwargs.get("name") # Try to get tool name from kwargs - - # Try to get tool info from either source - tool_info = self._get_tool_info(invocation_id, tool_name) - - if tool_info: - tool_name, tool_input, used_invocation_id = tool_info - if hasattr(output, "content"): - output = output.content - - self._put_to_queue( - { - "type": "tool", - "tool": tool_name, - "input": tool_input, - "output": str(output), - "status": "processing", # Use "processing" status for tool end - "created_at": datetime.datetime.now().isoformat(), - } - ) - logger.info( - f"Tool {tool_name} (ID: {used_invocation_id}) completed with output length: {len(str(output))}" - ) - - # Clean up tracking - if used_invocation_id in self.tool_states: - del self.tool_states[used_invocation_id] - del self.tool_inputs[used_invocation_id] - if tool_name in self.active_tools: - del self.active_tools[tool_name] - else: - logger.warning( - f"Tool end called with unknown invocation ID: {invocation_id} and tool name: {tool_name}" - ) - - def on_tool_error(self, error: Exception, **kwargs) -> None: - """Run when tool errors.""" - invocation_id = kwargs.get("invocation_id") - tool_name = kwargs.get("name") # Try to get tool name from kwargs - - # Try to get tool info from either source - tool_info = self._get_tool_info(invocation_id, tool_name) - - if tool_info: - tool_name, tool_input, used_invocation_id = tool_info - self._put_to_queue( - { - "type": "tool", - "tool": tool_name, - "input": tool_input, - "output": f"Error: {str(error)}", - "status": "error", - "created_at": datetime.datetime.now().isoformat(), - } - ) - logger.error( - f"Tool {tool_name} (ID: {used_invocation_id}) failed with error: {str(error)}", - exc_info=True, - ) - - # Clean up tracking - if used_invocation_id in self.tool_states: - del self.tool_states[used_invocation_id] - del self.tool_inputs[used_invocation_id] - if tool_name in self.active_tools: - del self.active_tools[tool_name] - else: - logger.warning( - f"Tool error called with unknown invocation ID: {invocation_id} and tool name: {tool_name}" - ) - - def on_llm_start(self, *args, **kwargs) -> None: - """Run when LLM starts running.""" - logger.info("LLM processing started") - - def on_llm_new_token(self, token: str, **kwargs) -> None: - """Run on new token.""" - # Check if we have planning_only in the kwargs - planning_only = kwargs.get("planning_only", False) - - # Handle custom token processing if provided - if self.custom_on_llm_new_token: - try: - # Check if it's a coroutine function and handle accordingly - if asyncio.iscoroutinefunction(self.custom_on_llm_new_token): - # For coroutines, we need to schedule it to run without awaiting - loop = self._ensure_loop() - # Create the coroutine object without calling it - coro = self.custom_on_llm_new_token(token, **kwargs) - # Schedule it to run in the event loop - asyncio.run_coroutine_threadsafe(coro, loop) - else: - # Regular function call - self.custom_on_llm_new_token(token, **kwargs) - except Exception as e: - logger.error(f"Error in custom token handler: {str(e)}", exc_info=True) - - # Log token information with phase information - phase = "planning" if planning_only else "processing" - logger.debug(f"Received new token (length: {len(token)}, phase: {phase})") - - def on_llm_end(self, response: LLMResult, **kwargs) -> None: - """Run when LLM ends running.""" - logger.info("LLM processing 
completed") - - # Queue an end message with complete status - try: - self._put_to_queue( - { - "type": "token", - "status": "complete", - "content": "", - "created_at": datetime.datetime.now().isoformat(), - } - ) - except Exception as e: - logger.error(f"Failed to queue completion message: {str(e)}") - - # Handle custom end processing if provided - if self.custom_on_llm_end: - try: - # Check if it's a coroutine function and handle accordingly - if asyncio.iscoroutinefunction(self.custom_on_llm_end): - # For coroutines, we need to schedule it to run without awaiting - loop = self._ensure_loop() - # Create the coroutine object without calling it - coro = self.custom_on_llm_end(response, **kwargs) - # Schedule it to run in the event loop - asyncio.run_coroutine_threadsafe(coro, loop) - else: - # Regular function call - self.custom_on_llm_end(response, **kwargs) - except Exception as e: - logger.error(f"Error in custom end handler: {str(e)}", exc_info=True) - - def on_llm_error(self, error: Exception, **kwargs) -> None: - """Run when LLM errors.""" - logger.error(f"LLM error occurred: {str(error)}", exc_info=True) - - # Send error status - try: - self._put_to_queue( - { - "type": "token", - "status": "error", - "content": f"Error: {str(error)}", - "created_at": datetime.datetime.now().isoformat(), - } - ) - except Exception: - pass # Don't raise another error if this fails - - raise ExecutionError("LLM processing failed", {"error": str(error)}) - - -class ReactState(TypedDict): - """State for the ReAct workflow.""" - - messages: Annotated[list, add_messages] - - -class ReactWorkflow(BaseWorkflow[ReactState]): - """ReAct workflow implementation.""" - - def __init__( - self, - callback_handler: StreamingCallbackHandler, - tools: List[Any], - **kwargs, - ): - super().__init__(**kwargs) - self.callback_handler = callback_handler - self.tools = tools - # Create a new LLM instance with the callback handler - self.llm = self.create_llm_with_callbacks([callback_handler]).bind_tools(tools) - self.required_fields = ["messages"] - - def _create_prompt(self) -> None: - """Not used in ReAct workflow.""" - pass - - def _create_graph(self) -> StateGraph: - """Create the ReAct workflow graph.""" - tool_node = ToolNode(self.tools) - - def should_continue(state: ReactState) -> str: - messages = state["messages"] - last_message = messages[-1] - result = "tools" if last_message.tool_calls else END - logger.debug(f"Continue decision: {result}") - return result - - def call_model(state: ReactState) -> Dict: - logger.debug("Calling model with current state") - messages = state["messages"] - response = self.llm.invoke(messages) - logger.debug("Received model response") - return {"messages": [response]} - - workflow = StateGraph(ReactState) - workflow.add_node("agent", call_model) - workflow.add_node("tools", tool_node) - workflow.add_edge(START, "agent") - workflow.add_conditional_edges("agent", should_continue) - workflow.add_edge("tools", "agent") - - return workflow - - -class LangGraphService: - """Service for executing LangGraph operations""" - - def __init__(self): - """Initialize the service.""" - self.message_processor = MessageProcessor() - - async def _execute_stream_impl( - self, - messages: List[Union[SystemMessage, HumanMessage, AIMessage]], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - **kwargs, - ) -> AsyncGenerator[Dict, None]: - """Execute a ReAct stream using LangGraph. 
- - Args: - messages: Processed messages ready for the LLM - input_str: Current user input - persona: Optional persona to use - tools_map: Optional tools to use - **kwargs: Additional arguments - - Returns: - Async generator of result chunks - """ - try: - # Import here to avoid circular dependencies - from services.workflows.workflow_service import ( - BaseWorkflowService, - WorkflowBuilder, - ) - - # Setup queue and callbacks - callback_queue = asyncio.Queue() - loop = asyncio.get_running_loop() - - # Setup callback handler - callback_handler = self.setup_callback_handler(callback_queue, loop) - - # Create workflow using builder pattern - workflow = ( - WorkflowBuilder(ReactWorkflow) - .with_callback_handler(callback_handler) - .with_tools(list(tools_map.values()) if tools_map else []) - .build() - ) - - # Create graph and compile - graph = workflow._create_graph() - runnable = graph.compile() - - # Execute workflow with callbacks config - config = {"callbacks": [callback_handler]} - task = asyncio.create_task( - runnable.ainvoke({"messages": messages}, config=config) - ) - - # Stream results - async for chunk in self.stream_task_results(task, callback_queue): - yield chunk - - except Exception as e: - logger.error(f"Failed to execute ReAct stream: {str(e)}", exc_info=True) - raise ExecutionError(f"ReAct stream execution failed: {str(e)}") - - def setup_callback_handler(self, queue, loop): - # Import here to avoid circular dependencies - from services.workflows.workflow_service import BaseWorkflowService - - # Use the static method instead of instantiating BaseWorkflowService - return BaseWorkflowService.create_callback_handler(queue, loop) - - async def stream_task_results(self, task, queue): - # Import here to avoid circular dependencies - from services.workflows.workflow_service import BaseWorkflowService - - # Use the static method instead of instantiating BaseWorkflowService - async for chunk in BaseWorkflowService.stream_results_from_task( - task=task, callback_queue=queue, logger_name=self.__class__.__name__ - ): - yield chunk - - # Keep the old method for backward compatibility - async def execute_react_stream( - self, - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - ) -> AsyncGenerator[Dict, None]: - """Execute a ReAct stream using LangGraph.""" - # Process messages for backward compatibility - filtered_content = self.message_processor.extract_filtered_content(history) - messages = self.message_processor.convert_to_langchain_messages( - filtered_content, input_str, persona - ) - - # Call the new implementation - async for chunk in self._execute_stream_impl( - messages=messages, - input_str=input_str, - persona=persona, - tools_map=tools_map, - ): - yield chunk - - # Add execute_stream as alias for consistency across services - async def execute_stream( - self, - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - **kwargs, - ) -> AsyncGenerator[Dict, None]: - """Execute a workflow stream. - - This is an alias for execute_react_stream to maintain consistent API - across different workflow services. 
- """ - async for chunk in self.execute_react_stream( - history=history, - input_str=input_str, - persona=persona, - tools_map=tools_map, - ): - yield chunk - - -# Facade function for backward compatibility -async def execute_langgraph_stream( - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, -) -> AsyncGenerator[Dict, None]: - """Execute a ReAct stream using LangGraph with optional persona.""" - service = LangGraphService() - async for chunk in service.execute_stream(history, input_str, persona, tools_map): - yield chunk diff --git a/services/workflows/utils.py b/services/workflows/utils.py new file mode 100644 index 00000000..fc1fb815 --- /dev/null +++ b/services/workflows/utils.py @@ -0,0 +1,117 @@ +"""Workflow utility functions.""" + +import binascii +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + + +def decode_hex_parameters(hex_string: Optional[str]) -> Optional[str]: + """Decodes a hexadecimal-encoded string if valid. + + Args: + hex_string: The hexadecimal string to decode. + + Returns: + The decoded string, or None if decoding fails. + """ + if not hex_string: + return None + if hex_string.startswith("0x"): + hex_string = hex_string[2:] # Remove "0x" prefix + try: + decoded_bytes = binascii.unhexlify(hex_string) + decoded_string = decoded_bytes.decode( + "utf-8", errors="ignore" + ) # Decode as UTF-8 + logger.debug(f"Successfully decoded hex string: {hex_string[:20]}...") + return decoded_string + except (binascii.Error, UnicodeDecodeError) as e: + logger.warning(f"Failed to decode hex string: {str(e)}") + return None # Return None if decoding fails + + +# Model pricing data (move this to a config or constants file later if needed) +MODEL_PRICES = { + "gpt-4o": { + "input": 2.50, # $2.50 per million input tokens + "output": 10.00, # $10.00 per million output tokens + }, + "gpt-4.1": { + "input": 2.00, # $2.00 per million input tokens + "output": 8.00, # $8.00 per million output tokens + }, + "gpt-4.1-mini": { + "input": 0.40, # $0.40 per million input tokens + "output": 1.60, # $1.60 per million output tokens + }, + "gpt-4.1-nano": { + "input": 0.10, # $0.10 per million input tokens + "output": 0.40, # $0.40 per million output tokens + }, + # Default to gpt-4.1 pricing if model not found + "default": { + "input": 2.00, + "output": 8.00, + }, +} + + +def calculate_token_cost( + token_usage: Dict[str, int], model_name: str +) -> Dict[str, float]: + """Calculate the cost of token usage based on current pricing. 
+ + Args: + token_usage: Dictionary containing input_tokens and output_tokens + model_name: Name of the model used + + Returns: + Dictionary containing cost breakdown and total cost + """ + # Get pricing for the model, default to gpt-4.1 pricing if not found + model_prices = MODEL_PRICES.get(model_name.lower(), MODEL_PRICES["default"]) + + # Extract token counts, ensuring we get integers and handle None values + try: + input_tokens = int(token_usage.get("input_tokens", 0)) + output_tokens = int(token_usage.get("output_tokens", 0)) + except (TypeError, ValueError) as e: + logger.error(f"Error converting token counts to integers: {str(e)}") + input_tokens = 0 + output_tokens = 0 + + # Calculate costs with more precision + input_cost = (input_tokens / 1_000_000.0) * model_prices["input"] + output_cost = (output_tokens / 1_000_000.0) * model_prices["output"] + total_cost = input_cost + output_cost + + # Create detailed token usage breakdown + token_details = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "model_name": model_name, + "input_price_per_million": model_prices["input"], + "output_price_per_million": model_prices["output"], + } + + # Add token details if available + if "input_token_details" in token_usage: + token_details["input_token_details"] = token_usage["input_token_details"] + if "output_token_details" in token_usage: + token_details["output_token_details"] = token_usage["output_token_details"] + + # Debug logging with more detail + logger.debug( + f"Cost calculation details: Model={model_name} | Input={input_tokens} tokens * ${model_prices['input']}/1M = ${input_cost:.6f} | Output={output_tokens} tokens * ${model_prices['output']}/1M = ${output_cost:.6f} | Total=${total_cost:.6f} | Token details={token_details}" + ) + + return { + "input_cost": round(input_cost, 6), + "output_cost": round(output_cost, 6), + "total_cost": round(total_cost, 6), + "currency": "USD", + "details": token_details, + } diff --git a/services/workflows/vector_mixin.py b/services/workflows/vector_mixin.py new file mode 100644 index 00000000..f6aaa750 --- /dev/null +++ b/services/workflows/vector_mixin.py @@ -0,0 +1,180 @@ +"""Vector retrieval mixin and vector document utilities for workflows.""" + +from typing import Any, Dict, List, Optional + +from langchain_core.documents import Document +from langchain_openai import OpenAIEmbeddings +from langgraph.graph import StateGraph + +from backend.factory import backend +from lib.logger import configure_logger +from services.workflows.base import BaseWorkflowMixin + +logger = configure_logger(__name__) + + +class VectorRetrievalCapability(BaseWorkflowMixin): + """Mixin that adds vector retrieval capabilities to a workflow.""" + + def __init__(self, *args, **kwargs): + """Initialize the vector retrieval capability.""" + super().__init__(*args, **kwargs) if hasattr(super(), "__init__") else None + self._init_vector_retrieval() + + def _init_vector_retrieval(self) -> None: + """Initialize vector retrieval attributes if not already initialized.""" + if not hasattr(self, "collection_names"): + self.collection_names = ["knowledge_collection", "dao_collection"] + if not hasattr(self, "embeddings"): + self.embeddings = OpenAIEmbeddings() + if not hasattr(self, "vector_results_cache"): + self.vector_results_cache = {} + + async def retrieve_from_vector_store(self, query: str, **kwargs) -> List[Document]: + """Retrieve relevant documents from multiple vector stores. 
+ + Args: + query: The query to search for + **kwargs: Additional arguments (collection_name, embeddings, etc.) + + Returns: + List of retrieved documents + """ + try: + self._init_vector_retrieval() + if query in self.vector_results_cache: + logger.debug(f"Using cached vector results for query: {query}") + return self.vector_results_cache[query] + all_documents = [] + limit_per_collection = kwargs.get("limit", 4) + logger.debug( + f"Searching vector store: query={query} | limit_per_collection={limit_per_collection}" + ) + for collection_name in self.collection_names: + try: + vector_results = await backend.query_vectors( + collection_name=collection_name, + query_text=query, + limit=limit_per_collection, + embeddings=self.embeddings, + ) + documents = [ + Document( + page_content=doc.get("page_content", ""), + metadata={ + **doc.get("metadata", {}), + "collection_source": collection_name, + }, + ) + for doc in vector_results + ] + all_documents.extend(documents) + logger.debug( + f"Retrieved {len(documents)} documents from collection {collection_name}" + ) + except Exception as e: + logger.error( + f"Failed to retrieve from collection {collection_name}: {str(e)}", + exc_info=True, + ) + continue + logger.debug( + f"Retrieved total of {len(all_documents)} documents from all collections" + ) + self.vector_results_cache[query] = all_documents + return all_documents + except Exception as e: + logger.error(f"Vector store retrieval failed: {str(e)}", exc_info=True) + return [] + + def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: + """Integrate vector retrieval capability with a graph. + + This adds the vector retrieval capability to the graph by adding a node + that can perform vector searches when needed. + + Args: + graph: The graph to integrate with + **kwargs: Additional arguments specific to vector retrieval including: + - collection_names: List of collection names to search + - limit_per_collection: Number of results per collection + """ + graph.add_node("vector_search", self.retrieve_from_vector_store) + if "process_vector_results" not in graph.nodes: + graph.add_node("process_vector_results", self._process_vector_results) + graph.add_edge("vector_search", "process_vector_results") + + async def _process_vector_results( + self, vector_results: List[Document], **kwargs + ) -> Dict[str, Any]: + """Process vector search results. + + Args: + vector_results: Results from vector search + **kwargs: Additional processing arguments + + Returns: + Processed results with metadata + """ + return { + "results": vector_results, + "metadata": { + "num_vector_results": len(vector_results), + "collection_sources": list( + set( + doc.metadata.get("collection_source", "unknown") + for doc in vector_results + ) + ), + }, + } + + +async def add_documents_to_vectors( + collection_name: str, + documents: List[Document], + embeddings: Optional[Any] = None, +) -> Dict[str, List[str]]: + """Add documents to a vector collection. 
+ + Args: + collection_name: Name of the collection to add to + documents: List of LangChain Document objects + embeddings: Optional embeddings model to use + + Returns: + Dictionary mapping collection name to list of document IDs + """ + if embeddings is None: + raise ValueError( + "Embeddings model must be provided to add documents to vector store" + ) + collection_doc_ids = {} + try: + try: + backend.get_vector_collection(collection_name) + except Exception: + embed_dim = 1536 + if hasattr(embeddings, "embedding_dim"): + embed_dim = embeddings.embedding_dim + backend.create_vector_collection(collection_name, dimensions=embed_dim) + texts = [doc.page_content for doc in documents] + embedding_vectors = embeddings.embed_documents(texts) + docs_for_storage = [ + {"page_content": doc.page_content, "embedding": embedding_vectors[i]} + for i, doc in enumerate(documents) + ] + metadata_list = [doc.metadata for doc in documents] + ids = await backend.add_vectors( + collection_name=collection_name, + documents=docs_for_storage, + metadata=metadata_list, + ) + collection_doc_ids[collection_name] = ids + logger.info(f"Added {len(ids)} documents to collection {collection_name}") + except Exception as e: + logger.error( + f"Failed to add documents to collection {collection_name}: {str(e)}" + ) + collection_doc_ids[collection_name] = [] + return collection_doc_ids diff --git a/services/workflows/vector_react.py b/services/workflows/vector_react.py deleted file mode 100644 index aa55f95d..00000000 --- a/services/workflows/vector_react.py +++ /dev/null @@ -1,443 +0,0 @@ -"""Vector-enabled ReAct workflow functionality with Supabase Vecs integration.""" - -import asyncio -from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict, Union - -from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage -from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from langgraph.graph import END, START, StateGraph -from langgraph.prebuilt import ToolNode - -from backend.factory import backend -from lib.logger import configure_logger -from services.workflows.base import ( - BaseWorkflow, - ExecutionError, - VectorRetrievalCapability, -) -from services.workflows.react import ( - MessageProcessor, - ReactState, - StreamingCallbackHandler, -) - -# Remove this import to avoid circular dependencies -# from services.workflows.workflow_service import BaseWorkflowService, WorkflowBuilder - -logger = configure_logger(__name__) - - -class VectorRetrievalState(TypedDict): - """State for vector retrieval step.""" - - query: str - documents: List[Document] - - -class VectorReactState(ReactState): - """State for the Vector ReAct workflow, extending ReactState.""" - - vector_results: Optional[List[Document]] - - -class VectorReactWorkflow(BaseWorkflow[VectorReactState], VectorRetrievalCapability): - """ReAct workflow with vector store integration.""" - - def __init__( - self, - callback_handler: StreamingCallbackHandler, - tools: List[Any], - collection_names: Union[ - str, List[str] - ], # Modified to accept single or multiple collections - embeddings: Optional[Embeddings] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.callback_handler = callback_handler - self.tools = tools - # Convert single collection to list for consistency - self.collection_names = ( - [collection_names] - if isinstance(collection_names, str) - else collection_names - ) - self.embeddings = embeddings or 
OpenAIEmbeddings() - self.required_fields = ["messages"] - - # Create a new LLM instance with the callback handler - self.llm = self.create_llm_with_callbacks([callback_handler]).bind_tools(tools) - - def _create_prompt(self) -> None: - """Not used in VectorReact workflow.""" - pass - - async def retrieve_from_vector_store(self, query: str, **kwargs) -> List[Document]: - """Retrieve relevant documents from multiple vector stores. - - Args: - query: The query to search for - **kwargs: Additional arguments - - Returns: - List of retrieved documents - """ - try: - all_documents = [] - limit_per_collection = kwargs.get( - "limit", 4 - ) # Get 4 results from each collection - - # Query each collection and gather results - for collection_name in self.collection_names: - try: - # Query vectors using the backend - vector_results = await backend.query_vectors( - collection_name=collection_name, - query_text=query, - limit=limit_per_collection, - embeddings=self.embeddings, - ) - - # Convert to LangChain Documents and add collection source - documents = [ - Document( - page_content=doc.get("page_content", ""), - metadata={ - **doc.get("metadata", {}), - "collection_source": collection_name, - }, - ) - for doc in vector_results - ] - - all_documents.extend(documents) - logger.info( - f"Retrieved {len(documents)} documents from collection {collection_name}" - ) - except Exception as e: - logger.error( - f"Failed to retrieve from collection {collection_name}: {str(e)}" - ) - continue # Continue with other collections if one fails - - logger.info( - f"Retrieved total of {len(all_documents)} documents from all collections" - ) - return all_documents - except Exception as e: - logger.error(f"Vector store retrieval failed: {str(e)}") - return [] - - def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: - """Integrate vector retrieval capability with a graph. 
- - Args: - graph: The graph to integrate with - **kwargs: Additional arguments - """ - # Modify the graph to include vector retrieval - # This is specific to the VectorReactWorkflow - pass - - def _create_graph(self) -> StateGraph: - """Create the VectorReact workflow graph.""" - tool_node = ToolNode(self.tools) - - def should_continue(state: VectorReactState) -> str: - messages = state["messages"] - last_message = messages[-1] - result = "tools" if last_message.tool_calls else END - logger.debug(f"Continue decision: {result}") - return result - - async def retrieve_from_vector_store(state: VectorReactState) -> Dict: - """Retrieve relevant documents from vector store.""" - messages = state["messages"] - # Get the last user message - last_user_message = None - for message in reversed(messages): - if isinstance(message, HumanMessage): - last_user_message = message.content - break - - if not last_user_message: - logger.warning("No user message found for vector retrieval") - return {"vector_results": []} - - documents = await self.retrieve_from_vector_store(query=last_user_message) - return {"vector_results": documents} - - def call_model_with_context(state: VectorReactState) -> Dict: - """Call model with additional context from vector store.""" - messages = state["messages"] - vector_results = state.get("vector_results", []) - - # Add vector context to the system message if available - context_message = None - - if vector_results: - # Format the vector results into a context string - context_str = "\n\n".join([doc.page_content for doc in vector_results]) - context_message = SystemMessage( - content=f"Here is additional context that may be helpful:\n\n{context_str}\n\n" - "Use this context to inform your response if relevant." - ) - messages = [context_message] + messages - - logger.debug( - f"Calling model with {len(messages)} messages and " - f"{len(vector_results)} retrieved documents" - ) - - response = self.llm.invoke(messages) - return {"messages": [response]} - - workflow = StateGraph(VectorReactState) - workflow.add_node("vector_retrieval", retrieve_from_vector_store) - workflow.add_node("agent", call_model_with_context) - workflow.add_node("tools", tool_node) - - # Set up the execution flow - workflow.add_edge(START, "vector_retrieval") - workflow.add_edge("vector_retrieval", "agent") - workflow.add_conditional_edges("agent", should_continue) - workflow.add_edge("tools", "agent") - - return workflow - - -class VectorLangGraphService: - """Service for executing VectorReact LangGraph operations""" - - def __init__( - self, - collection_names: Union[ - str, List[str] - ], # Modified to accept single or multiple collections - embeddings: Optional[Embeddings] = None, - ): - # Import here to avoid circular imports - from services.workflows.react import MessageProcessor - - self.collection_names = collection_names - self.embeddings = embeddings or OpenAIEmbeddings() - self.message_processor = MessageProcessor() - - def setup_callback_handler(self, queue, loop): - # Import here to avoid circular dependencies - from services.workflows.workflow_service import BaseWorkflowService - - # Use the static method instead of instantiating BaseWorkflowService - return BaseWorkflowService.create_callback_handler(queue, loop) - - async def stream_task_results(self, task, queue): - # Import here to avoid circular dependencies - from services.workflows.workflow_service import BaseWorkflowService - - # Use the static method instead of instantiating BaseWorkflowService - async for chunk in 
BaseWorkflowService.stream_results_from_task( - task=task, callback_queue=queue, logger_name=self.__class__.__name__ - ): - yield chunk - - async def _execute_stream_impl( - self, - messages: List[Union[SystemMessage, HumanMessage, AIMessage]], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - **kwargs, - ) -> AsyncGenerator[Dict, None]: - """Execute a Vector React stream implementation. - - Args: - messages: Processed messages - input_str: Current user input - persona: Optional persona to use - tools_map: Optional tools to use - **kwargs: Additional arguments - - Returns: - Async generator of result chunks - """ - try: - # Import here to avoid circular dependencies - from services.workflows.workflow_service import WorkflowBuilder - - # Setup queue and callbacks - callback_queue = asyncio.Queue() - loop = asyncio.get_running_loop() - - # Setup callback handler - callback_handler = self.setup_callback_handler(callback_queue, loop) - - # Create workflow using builder pattern - workflow = ( - WorkflowBuilder(VectorReactWorkflow) - .with_callback_handler(callback_handler) - .with_tools(list(tools_map.values()) if tools_map else []) - .build( - collection_names=self.collection_names, - embeddings=self.embeddings, - ) - ) - - # Create graph and compile - graph = workflow._create_graph() - runnable = graph.compile() - - # Execute workflow with callbacks config - config = {"callbacks": [callback_handler]} - task = asyncio.create_task( - runnable.ainvoke( - {"messages": messages, "vector_results": []}, config=config - ) - ) - - # Stream results - async for chunk in self.stream_task_results(task, callback_queue): - yield chunk - - except Exception as e: - logger.error( - f"Failed to execute VectorReact stream: {str(e)}", exc_info=True - ) - raise ExecutionError(f"VectorReact stream execution failed: {str(e)}") - - # Add execute_stream method to maintain the same interface as BaseWorkflowService - async def execute_stream( - self, - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - **kwargs, - ) -> AsyncGenerator[Dict, None]: - """Execute a workflow stream. - - This processes the history and delegates to _execute_stream_impl. - """ - # Process messages - filtered_content = self.message_processor.extract_filtered_content(history) - messages = self.message_processor.convert_to_langchain_messages( - filtered_content, input_str, persona - ) - - # Call the implementation - async for chunk in self._execute_stream_impl( - messages=messages, - input_str=input_str, - persona=persona, - tools_map=tools_map, - **kwargs, - ): - yield chunk - - # Keep the old method for backward compatibility - async def execute_vector_react_stream( - self, - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - ) -> AsyncGenerator[Dict, None]: - """Execute a VectorReact stream using LangGraph.""" - # Call the new method - async for chunk in self.execute_stream(history, input_str, persona, tools_map): - yield chunk - - -# Helper function for adding documents to vector store -async def add_documents_to_vectors( - collection_name: str, # Modified to only accept a single collection - documents: List[Document], - embeddings: Optional[Embeddings] = None, -) -> Dict[str, List[str]]: - """Add documents to vector collection. 
- - Args: - collection_name: Name of the collection to add to - documents: List of LangChain Document objects - embeddings: Optional embeddings model to use - - Returns: - Dictionary mapping collection name to list of document IDs - """ - # Ensure embeddings model is provided - if embeddings is None: - raise ValueError( - "Embeddings model must be provided to add documents to vector store" - ) - - # Store document IDs for the collection - collection_doc_ids = {} - - try: - # Ensure collection exists - try: - backend.get_vector_collection(collection_name) - except Exception: - # Create collection if it doesn't exist - embed_dim = 1536 # Default for OpenAI embeddings - if hasattr(embeddings, "embedding_dim"): - embed_dim = embeddings.embedding_dim - backend.create_vector_collection(collection_name, dimensions=embed_dim) - - # Extract texts for embedding - texts = [doc.page_content for doc in documents] - - # Generate embeddings for the texts - embedding_vectors = embeddings.embed_documents(texts) - - # Prepare documents for storage with embeddings - docs_for_storage = [ - {"page_content": doc.page_content, "embedding": embedding_vectors[i]} - for i, doc in enumerate(documents) - ] - - # Prepare metadata - metadata_list = [doc.metadata for doc in documents] - - # Add to vector store - ids = await backend.add_vectors( - collection_name=collection_name, - documents=docs_for_storage, - metadata=metadata_list, - ) - - collection_doc_ids[collection_name] = ids - logger.info(f"Added {len(ids)} documents to collection {collection_name}") - - except Exception as e: - logger.error( - f"Failed to add documents to collection {collection_name}: {str(e)}" - ) - collection_doc_ids[collection_name] = [] - - return collection_doc_ids - - -# Facade function for backward compatibility -async def execute_vector_langgraph_stream( - collection_names: Union[ - str, List[str] - ], # Modified to accept single or multiple collections - history: List[Dict], - input_str: str, - persona: Optional[str] = None, - tools_map: Optional[Dict] = None, - embeddings: Optional[Embeddings] = None, -) -> AsyncGenerator[Dict, None]: - """Execute a VectorReact stream using LangGraph with vector store integration.""" - # Initialize service and run stream - embeddings = embeddings or OpenAIEmbeddings() - service = VectorLangGraphService( - collection_names=collection_names, - embeddings=embeddings, - ) - - async for chunk in service.execute_stream(history, input_str, persona, tools_map): - yield chunk diff --git a/services/workflows/web_search.py b/services/workflows/web_search.py deleted file mode 100644 index e7a3155f..00000000 --- a/services/workflows/web_search.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Web search workflow implementation using OpenAI Assistant API.""" - -import asyncio -import json -from typing import Any, Dict, List, Optional - -from langchain_core.messages import AIMessage, HumanMessage -from langgraph.graph import StateGraph -from openai import OpenAI -from openai.types.beta.assistant import Assistant -from openai.types.beta.thread import Thread -from openai.types.beta.threads.thread_message import ThreadMessage - -from lib.logger import configure_logger -from services.workflows.base import BaseWorkflow, WebSearchCapability -from services.workflows.vector import VectorRetrievalCapability - -logger = configure_logger(__name__) - - -class WebSearchWorkflow(BaseWorkflow, WebSearchCapability, VectorRetrievalCapability): - """Workflow that combines web search with vector retrieval capabilities using OpenAI Assistant.""" 
- - def __init__(self, **kwargs): - """Initialize the workflow. - - Args: - **kwargs: Additional arguments passed to parent classes - """ - super().__init__(**kwargs) - self.search_results_cache = {} - self.client = OpenAI() - # Create an assistant with web browsing capability - self.assistant: Assistant = self.client.beta.assistants.create( - name="Web Search Assistant", - description="Assistant that helps with web searches", - model="gpt-4-turbo-preview", - tools=[{"type": "retrieval"}, {"type": "web_browser"}], - instructions="""You are a web search assistant. Your primary task is to: - 1. Search the web for relevant information - 2. Extract key information from web pages - 3. Provide detailed, accurate responses with source URLs - 4. Format responses as structured data with content and metadata - Always include source URLs in your responses.""", - ) - - async def search_web(self, query: str, **kwargs) -> List[Dict[str, Any]]: - """Search the web using OpenAI Assistant API. - - Args: - query: The search query - **kwargs: Additional search parameters - - Returns: - List of search results with content and metadata - """ - try: - # Check cache first - if query in self.search_results_cache: - logger.info(f"Using cached results for query: {query}") - return self.search_results_cache[query] - - # Create a new thread for this search - thread: Thread = self.client.beta.threads.create() - - # Add the user's message to the thread - self.client.beta.threads.messages.create( - thread_id=thread.id, - role="user", - content=f"Search the web for: {query}. Please provide detailed information with source URLs.", - ) - - # Run the assistant - run = self.client.beta.threads.runs.create( - thread_id=thread.id, assistant_id=self.assistant.id - ) - - # Wait for completion - while True: - run_status = self.client.beta.threads.runs.retrieve( - thread_id=thread.id, run_id=run.id - ) - if run_status.status == "completed": - break - elif run_status.status in ["failed", "cancelled", "expired"]: - raise Exception( - f"Assistant run failed with status: {run_status.status}" - ) - await asyncio.sleep(1) # Wait before checking again - - # Get the assistant's response - messages: List[ThreadMessage] = self.client.beta.threads.messages.list( - thread_id=thread.id - ) - - # Process the response into our document format - documents = [] - for message in messages: - if message.role == "assistant": - for content in message.content: - if content.type == "text": - # Extract URLs from annotations if available - urls = [] - if message.metadata and "citations" in message.metadata: - urls = [ - cite["url"] - for cite in message.metadata["citations"] - ] - - # Create document with content and metadata - doc = { - "page_content": content.text, - "metadata": { - "type": "web_search_result", - "source_urls": urls, - "query": query, - "timestamp": message.created_at, - }, - } - documents.append(doc) - - # Cache the results - self.search_results_cache[query] = documents - - logger.info(f"Web search completed with {len(documents)} results") - return documents - - except Exception as e: - logger.error(f"Web search failed: {str(e)}") - return [] - - async def execute(self, query: str, **kwargs) -> Dict[str, Any]: - """Execute the web search workflow. - - This workflow: - 1. Searches the web for relevant information - 2. Processes and stores the results - 3. 
Combines with vector retrieval if available - - Args: - query: The search query - **kwargs: Additional execution arguments - - Returns: - Dict containing search results and any additional data - """ - try: - # Perform web search - web_results = await self.search_web(query, **kwargs) - - # Cache results - self.search_results_cache[query] = web_results - - # Combine with vector retrieval if available - combined_results = web_results - try: - vector_results = await self.retrieve_from_vectorstore(query, **kwargs) - combined_results.extend(vector_results) - except Exception as e: - logger.warning( - f"Vector retrieval failed, using only web results: {str(e)}" - ) - - return { - "query": query, - "results": combined_results, - "source": "web_search_workflow", - "metadata": { - "num_web_results": len(web_results), - "has_vector_results": ( - bool(vector_results) if "vector_results" in locals() else False - ), - }, - } - - except Exception as e: - logger.error(f"Web search workflow execution failed: {str(e)}") - raise - - def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: - """Integrate web search workflow with a graph. - - Args: - graph: The graph to integrate with - **kwargs: Additional integration arguments - """ - # Add web search node - graph.add_node("web_search", self.search_web) - - # Add vector retrieval node if available - try: - graph.add_node("vector_retrieval", self.retrieve_from_vectorstore) - - # Connect nodes - graph.add_edge("web_search", "vector_retrieval") - except Exception as e: - logger.warning(f"Vector retrieval integration failed: {str(e)}") - - # Add result processing node - graph.add_node("process_results", self._process_results) - graph.add_edge("vector_retrieval", "process_results") - - async def _process_results( - self, - web_results: List[Dict[str, Any]], - vector_results: Optional[List[Dict[str, Any]]] = None, - ) -> Dict[str, Any]: - """Process and combine search results. 
- - Args: - web_results: Results from web search - vector_results: Optional results from vector retrieval - - Returns: - Processed and combined results - """ - combined_results = web_results.copy() - if vector_results: - combined_results.extend(vector_results) - - # Deduplicate results based on content similarity - seen_contents = set() - unique_results = [] - for result in combined_results: - content = result.get("page_content", "") - content_hash = hash(content) - if content_hash not in seen_contents: - seen_contents.add(content_hash) - unique_results.append(result) - - return { - "results": unique_results, - "metadata": { - "num_web_results": len(web_results), - "num_vector_results": len(vector_results) if vector_results else 0, - "num_unique_results": len(unique_results), - }, - } diff --git a/services/workflows/web_search_mixin.py b/services/workflows/web_search_mixin.py new file mode 100644 index 00000000..e3bf89a6 --- /dev/null +++ b/services/workflows/web_search_mixin.py @@ -0,0 +1,197 @@ +"""Web search mixin for workflows, providing web search capabilities using OpenAI Responses API.""" + +from typing import Any, Dict, List, Tuple + +from langgraph.graph import StateGraph +from openai import OpenAI + +from lib.logger import configure_logger +from services.workflows.base import BaseWorkflowMixin + +logger = configure_logger(__name__) + + +class WebSearchCapability(BaseWorkflowMixin): + """Mixin that adds web search capabilities to a workflow using OpenAI Responses API.""" + + def __init__(self, *args, **kwargs): + """Initialize the web search capability.""" + # Initialize parent class if it exists + super().__init__(*args, **kwargs) if hasattr(super(), "__init__") else None + # Initialize our attributes + self._init_web_search() + + def _init_web_search(self) -> None: + """Initialize web search attributes if not already initialized.""" + if not hasattr(self, "search_results_cache"): + self.search_results_cache = {} + if not hasattr(self, "client"): + self.client = OpenAI() + + async def search_web( + self, query: str, **kwargs + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + """Search the web using OpenAI Responses API. + + Args: + query: The search query + **kwargs: Additional search parameters like user_location and search_context_size + + Returns: + Tuple containing list of search results and token usage dict. 
+ """ + try: + # Ensure initialization + self._init_web_search() + + # Check cache first + if query in self.search_results_cache: + logger.info(f"Using cached results for query: {query}") + return self.search_results_cache[query], { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + + # Configure web search tool + tool_config = { + "type": "web_search_preview", + "search_context_size": kwargs.get("search_context_size", "medium"), + } + + # Add user location if provided + if "user_location" in kwargs: + tool_config["user_location"] = kwargs["user_location"] + + # Make the API call + response = self.client.responses.create( + model="gpt-4.1", tools=[tool_config], input=query + ) + + token_usage = response.usage # Access the usage object + standardized_usage = { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + if token_usage: # Check if usage data exists + standardized_usage = { + "input_tokens": token_usage.prompt_tokens, # Access via attribute + "output_tokens": token_usage.completion_tokens, # Access via attribute + "total_tokens": token_usage.total_tokens, # Access via attribute + } + + logger.debug(f"Web search response: {response}") + logger.debug(f"Web search token usage: {standardized_usage}") + # Process the response into our document format + documents = [] + + # Access the output text directly + if hasattr(response, "output_text"): + text_content = response.output_text + source_urls = [] + + # Try to extract citations if available + if hasattr(response, "citations"): + source_urls = [ + { + "url": citation.url, + "title": getattr(citation, "title", ""), + "start_index": getattr(citation, "start_index", 0), + "end_index": getattr(citation, "end_index", 0), + } + for citation in response.citations + if hasattr(citation, "url") + ] + + # Ensure we always have at least one URL entry + if not source_urls: + source_urls = [ + { + "url": "No source URL available", + "title": "Generated Response", + "start_index": 0, + "end_index": len(text_content), + } + ] + + # Create document with content + doc = { + "page_content": text_content, + "metadata": { + "type": "web_search_result", + "source_urls": source_urls, + "query": query, + "timestamp": None, + }, + } + documents.append(doc) + + # Cache the results + self.search_results_cache[query] = documents + + logger.info(f"Web search completed with {len(documents)} results") + return documents, standardized_usage + + except Exception as e: + logger.error(f"Web search failed: {str(e)}") + # Return empty list and zero usage on error + error_doc = [ + { + "page_content": "Web search failed to return results.", + "metadata": { + "type": "web_search_result", + "source_urls": [ + { + "url": "Error occurred during web search", + "title": "Error", + "start_index": 0, + "end_index": 0, + } + ], + "query": query, + "timestamp": None, + }, + } + ] + return error_doc, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + + def integrate_with_graph(self, graph: StateGraph, **kwargs) -> None: + """Integrate web search capability with a graph. + + This adds the web search capability to the graph by adding a node + that can perform web searches when needed. 
+ + Args: + graph: The graph to integrate with + **kwargs: Additional arguments specific to web search including: + - search_context_size: "low", "medium", or "high" + - user_location: dict with type, country, city, region + """ + # Add web search node + graph.add_node("web_search", self.search_web) + + # Add result processing node if needed + if "process_results" not in graph.nodes: + graph.add_node("process_results", self._process_results) + graph.add_edge("web_search", "process_results") + + async def _process_results( + self, web_results: List[Dict[str, Any]], **kwargs + ) -> Dict[str, Any]: + """Process web search results. + + Args: + web_results: Results from web search + **kwargs: Additional processing arguments + + Returns: + Processed results with metadata + """ + return { + "results": web_results, + "metadata": { + "num_web_results": len(web_results), + "source_types": ["web_search"], + }, + } diff --git a/services/workflows/workflow_service.py b/services/workflows/workflow_service.py index a4ecbdf8..2a4a6921 100644 --- a/services/workflows/workflow_service.py +++ b/services/workflows/workflow_service.py @@ -16,12 +16,11 @@ from lib.logger import configure_logger from services.workflows.base import ExecutionError, StreamingError -from services.workflows.react import ( - LangGraphService, +from services.workflows.chat import ( + ChatService, MessageProcessor, StreamingCallbackHandler, ) -from services.workflows.vector_react import VectorLangGraphService logger = configure_logger(__name__) @@ -508,64 +507,36 @@ def build(self, **extra_kwargs) -> Any: class WorkflowFactory: - """Factory for creating workflow service instances.""" + """Factory for creating workflow service instances. Only ChatService is used.""" @classmethod def create_workflow_service( cls, - workflow_type: str = "react", + workflow_type: str = "chat", vector_collections: Optional[Union[str, List[str]]] = None, embeddings: Optional[Embeddings] = None, **kwargs, ) -> WorkflowService: - """Create a workflow service instance based on the workflow type. + """Create a workflow service instance. Always returns ChatService. 
Args: - workflow_type: Type of workflow to create ("react", "preplan", "vector", "vector_preplan") + workflow_type: Type of workflow to create (ignored, always uses ChatService) vector_collections: Vector collection name(s) for vector workflows embeddings: Embeddings model for vector workflows **kwargs: Additional parameters to pass to the service Returns: - An instance of a WorkflowService implementation + An instance of ChatService """ - # Import service classes here to avoid circular imports - from services.workflows.preplan_react import PreplanLangGraphService - from services.workflows.vector_preplan_react import ( - VectorPreplanLangGraphService, - ) - - # Map workflow types to their service classes - service_map = { - "react": LangGraphService, - "preplan": PreplanLangGraphService, - "vector": VectorLangGraphService, - "vector_preplan": VectorPreplanLangGraphService, - } - - if workflow_type not in service_map: - raise ValueError(f"Unsupported workflow type: {workflow_type}") - - service_class = service_map[workflow_type] - - # Handle vector-based workflow special cases - if workflow_type in ["vector", "vector_preplan"]: - if not vector_collections: - raise ValueError( - f"Vector collection name(s) required for {workflow_type} workflow" - ) - + if vector_collections is not None: if not embeddings: embeddings = OpenAIEmbeddings() - - return service_class( + return ChatService( collection_names=vector_collections, embeddings=embeddings, **kwargs, ) - - # For other workflow types - return service_class(**kwargs) + return ChatService(**kwargs) async def execute_workflow_stream( @@ -578,10 +549,10 @@ async def execute_workflow_stream( embeddings: Optional[Embeddings] = None, **kwargs, ) -> AsyncGenerator[Dict, None]: - """Unified interface for executing any workflow stream. + """Unified interface for executing any workflow stream. Uses ChatService for all workflows. Args: - workflow_type: Type of workflow to execute + workflow_type: Type of workflow to execute (ignored) history: Conversation history input_str: Current user input persona: Optional persona to use @@ -599,8 +570,6 @@ async def execute_workflow_stream( embeddings=embeddings, **kwargs, ) - - # Execute the stream through the service's execute_stream method async for chunk in service.execute_stream( history=history, input_str=input_str, diff --git a/tests/services/workflows/test_vector_react.py b/tests/services/workflows/test_vector_react.py index ffd3ac9e..0bc45eef 100644 --- a/tests/services/workflows/test_vector_react.py +++ b/tests/services/workflows/test_vector_react.py @@ -5,11 +5,11 @@ from langchain_core.documents import Document -from services.workflows.vector_react import ( +from services.workflows.chat import ( VectorLangGraphService, - VectorReactWorkflow, - add_documents_to_vectors, + execute_vector_langgraph_stream, ) +from services.workflows.vector_mixin import add_documents_to_vectors class TestVectorOperations(unittest.TestCase): diff --git a/vector_react_example.py b/vector_react_example.py index 56391f30..11e42dca 100644 --- a/vector_react_example.py +++ b/vector_react_example.py @@ -15,10 +15,8 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter from backend.factory import backend -from services.workflows.vector_react import ( - add_documents_to_vectors, - execute_vector_langgraph_stream, -) +from services.workflows.chat import VectorLangGraphService +from services.workflows.vector_mixin import add_documents_to_vectors dotenv.load_dotenv()
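
Example usage of the unified entry point that replaces `execute_langgraph_stream` at the call sites above. This is a minimal consumption sketch, not part of the changeset: it assumes the chunk contract produced by `StreamingCallbackHandler` (`"token"`, `"tool"`, and `"result"` chunk types) still reaches the caller unchanged, and the empty history and tools map are placeholders.

```python
# Minimal sketch, not part of the changeset: assumes execute_workflow_stream still yields
# the chunk types produced by StreamingCallbackHandler ("token", "tool", "result").
import asyncio

from services.workflows import execute_workflow_stream


async def run_example() -> None:
    final_chunk = None
    async for chunk in execute_workflow_stream(
        history=[],  # prior conversation messages, if any
        input_str="Summarize the latest DAO proposal",
        tools_map={},  # placeholder; real callers pass their filtered tools map
    ):
        if chunk["type"] == "token":
            print(chunk.get("content", ""), end="", flush=True)
        elif chunk["type"] == "tool":
            print(f"\n[tool] {chunk['tool']}: {chunk.get('output', '')[:80]}")
        elif chunk["type"] == "result":
            final_chunk = chunk
    print("\nReceived final result:", final_chunk is not None)


if __name__ == "__main__":
    asyncio.run(run_example())
```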
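
The new `services/workflows/utils.py` helper centralizes the per-model pricing math. As a worked example against the pricing table above: 1,200 input tokens on `gpt-4.1-mini` cost 1,200 / 1,000,000 × $0.40 = $0.00048 and 300 output tokens cost 300 / 1,000,000 × $1.60 = $0.00048, for $0.00096 total. The sketch below exercises that calculation; the token counts are invented for illustration.

```python
# Minimal sketch, not part of the changeset: the token counts below are invented
# to illustrate calculate_token_cost from services/workflows/utils.py.
from services.workflows.utils import calculate_token_cost

usage = {"input_tokens": 1200, "output_tokens": 300}
cost = calculate_token_cost(usage, "gpt-4.1-mini")

# (1200 * $0.40 + 300 * $1.60) / 1,000,000 = $0.00096
assert abs(cost["total_cost"] - 0.00096) < 1e-9
print(cost["details"]["total_tokens"])                  # 1500
print(f"${cost['total_cost']:.6f} {cost['currency']}")  # $0.000960 USD
```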
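
Document ingestion moves from the deleted `vector_react.py` to `services/workflows/vector_mixin.py`, and `add_documents_to_vectors` requires an explicit embeddings model (a `ValueError` is raised otherwise), as `vector_react_example.py` reflects. Below is a minimal ingestion-and-query sketch under assumed settings (a configured vector backend and `OPENAI_API_KEY` in the environment); the collection name and document content are illustrative, and the query step uses the same `backend.query_vectors` call that `VectorRetrievalCapability` wraps.

```python
# Minimal sketch, not part of the changeset: assumes a configured vector backend and
# OPENAI_API_KEY; the collection name and document content are illustrative only.
import asyncio

from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings

from backend.factory import backend
from services.workflows.vector_mixin import add_documents_to_vectors


async def run_example() -> None:
    embeddings = OpenAIEmbeddings()

    # Ingest: an embeddings model must be passed explicitly or a ValueError is raised.
    docs = [
        Document(
            page_content="Example DAO charter text.",
            metadata={"source": "example"},
        )
    ]
    ids = await add_documents_to_vectors(
        collection_name="example_collection",
        documents=docs,
        embeddings=embeddings,
    )
    print(f"Stored {len(ids.get('example_collection', []))} document(s)")

    # Query with the same backend call VectorRetrievalCapability uses internally.
    results = await backend.query_vectors(
        collection_name="example_collection",
        query_text="What does the charter say?",
        limit=2,
        embeddings=embeddings,
    )
    print(f"Retrieved {len(results)} document(s)")


if __name__ == "__main__":
    asyncio.run(run_example())
```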