Skip to content

Commit 1ecc2a4

Browse files
committed
Add agent workflow with multi-agent routing and query processing
1 parent be7cc2f commit 1ecc2a4

File tree

8 files changed

+1001
-33
lines changed

8 files changed

+1001
-33
lines changed

app/v1/routers/agent_workflow/__init__.py

Whitespace-only changes.
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from typing import Any, Literal, Optional # noqa: UP035
2+
3+
from pydantic import BaseModel
4+
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
5+
6+
from app.tools.data_analyst_agent import DataAnalystAgent
7+
from app.tools.sql_data_analyst_agent import SqlDataAnalystAgent
8+
from app.v1.routers.agent_workflow.persistence import ChromaDBStatePersistence, get_chroma_client
9+
10+
11+
class AgentState(BaseModel):
    """State model for the agent workflow graph.

    Carries the user's query through the graph and collects the outcome
    produced by whichever agent node handled it.
    """

    # The raw user query to be classified and routed to an agent.
    query: str
    # Result payload produced by the selected agent; None until a node runs.
    result: Optional[dict[str, Any]] = None
    # Name of the agent that produced `result` ("sql_data_analyst" or
    # "data_analyst"); None until a node runs.
    agent_used: Optional[str] = None
17+
18+
19+
class QueryClassifier(BaseNode[AgentState]):
    """Node to classify the query and determine which agent to use.

    Scores the query against two keyword lists (SQL vs. data analysis) and
    routes to whichever agent scored higher; ties fall through to the data
    analyst agent.
    """

    async def run(self, ctx: GraphRunContext[AgentState]) -> Literal["use_sql_agent", "use_data_agent"]:
        """Classify the query to determine which agent to use.

        Keywords are matched as whole words/phrases (word-boundary regex), not
        raw substrings — otherwise e.g. "sum" would match "summary" and
        "count" would match "country", skewing the scores.

        NOTE(review): this returns edge-name strings; stock pydantic-graph
        nodes usually return the next node instance — confirm this matches the
        Graph wiring used in AgentWorkflowGraph.

        Returns:
            str: Either "use_sql_agent" or "use_data_agent"

        """
        import re  # local import keeps the classifier self-contained

        query = ctx.state.query.lower()

        # Keywords that suggest SQL operations
        sql_keywords = [
            "sql",
            "query",
            "table",
            "database",
            "select",
            "join",
            "where",
            "group by",
            "order by",
            "having",
            "count",
            "sum",
            "average",
            "filter",
            "find records",
            "search records",
        ]

        # Keywords that suggest data analysis or visualization
        data_keywords = [
            "analyze",
            "chart",
            "graph",
            "plot",
            "visualization",
            "trend",
            "pattern",
            "correlation",
            "distribution",
            "histogram",
            "scatter plot",
            "bar chart",
            "pie chart",
            "dashboard",
        ]

        def _score(keywords: list[str]) -> int:
            # \b anchors restrict hits to whole words/phrases.
            return sum(1 for kw in keywords if re.search(rf"\b{re.escape(kw)}\b", query))

        sql_score = _score(sql_keywords)
        data_score = _score(data_keywords)

        # Ties (including zero hits on both sides) keep the original default:
        # route to the data analyst agent.
        if sql_score > data_score:
            return "use_sql_agent"
        return "use_data_agent"
78+
79+
80+
class UseSqlAgent(BaseNode[AgentState]):
    """Node that delegates the query to the SQL Data Analyst Agent."""

    async def run(self, ctx: GraphRunContext[AgentState]) -> End[dict[str, Any]]:
        """Run the SQL Data Analyst Agent on the current query.

        Returns:
            End: The result from the SQL Data Analyst Agent

        """
        agent_name = "sql_data_analyst"
        outcome = await SqlDataAnalystAgent().process_query(ctx.state.query)

        # Record the outcome on the shared graph state before terminating.
        ctx.state.result = outcome
        ctx.state.agent_used = agent_name

        # End the graph run, surfacing both the payload and which agent ran.
        return End({"result": outcome, "agent_used": agent_name})
99+
100+
101+
class UseDataAgent(BaseNode[AgentState]):
    """Node that delegates the query to the Data Analyst Agent."""

    async def run(self, ctx: GraphRunContext[AgentState]) -> End[dict[str, Any]]:
        """Run the Data Analyst Agent on the current query.

        Returns:
            End: The result from the Data Analyst Agent

        """
        agent_name = "data_analyst"
        # NOTE(review): unlike the SQL node, this call is not awaited —
        # presumably DataAnalystAgent.process_query is synchronous; confirm,
        # otherwise `result` would hold an un-awaited coroutine.
        outcome = DataAnalystAgent().process_query(ctx.state.query)

        # Record the outcome on the shared graph state before terminating.
        ctx.state.result = outcome
        ctx.state.agent_used = agent_name

        # End the graph run, surfacing both the payload and which agent ran.
        return End({"result": outcome, "agent_used": agent_name})
120+
121+
122+
class AgentWorkflowGraph:
    """Multi-agent workflow graph implementation using pydantic-ai and ChromaDB.

    This class manages the workflow that selects and runs the appropriate agent
    based on the user's query.
    """

    def __init__(self):
        """Initialize the agent workflow graph.

        Builds the routing graph (classifier -> one of two agent nodes) and a
        ChromaDB-backed persistence layer for graph state snapshots.
        """
        # Create the graph with nodes and edges.
        # "start" classifies the query; the edge names it returns select which
        # agent node terminates the run.
        nodes = {
            "start": QueryClassifier(),
            "use_sql_agent": UseSqlAgent(),
            "use_data_agent": UseDataAgent(),
        }

        edges = {
            "start": {
                "use_sql_agent": "use_sql_agent",
                "use_data_agent": "use_data_agent",
            }
        }

        # Create the graph using named parameters.
        # NOTE(review): documented pydantic-graph `Graph` takes a *sequence* of
        # node types and derives edges from the nodes' return annotations; it
        # has no `edges` keyword. Confirm this signature against the installed
        # pydantic-graph version — if it is the upstream API, this constructor
        # call will fail at runtime.
        self.graph = Graph(
            state_type=AgentState,
            nodes=nodes,
            edges=edges,
        )

        # Set up persistence with ChromaDB (in-memory client: no directory is
        # passed to get_chroma_client, so states do not survive restarts).
        chroma_client = get_chroma_client()
        self.persistence = ChromaDBStatePersistence(
            chroma_client=chroma_client,
            collection_name="agent_workflow_states",
        )

    async def process_query(self, query: str) -> dict[str, Any]:
        """Process a user query through the multi-agent workflow.

        Args:
            query: The user's query string

        Returns:
            dict: A dictionary containing the result and the agent used

        """
        # Initialize the state with the query
        state = AgentState(query=query)

        # Run the graph with the state.
        # NOTE(review): upstream pydantic-graph exposes `run(start_node,
        # state=..., persistence=...)`, not `arun(...)` with a
        # `state_persistence` kwarg — verify against the installed version.
        return await self.graph.arun(state, state_persistence=self.persistence)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import json
2+
import os
3+
import uuid
4+
5+
from typing import Any, Optional # noqa: UP035
6+
7+
import chromadb
8+
9+
from chromadb.config import Settings
10+
from pydantic_graph.persistence import BaseStatePersistence
11+
12+
13+
def get_chroma_client(persist_directory: Optional[str] = None) -> chromadb.Client:
    """Build a ChromaDB client, persistent or in-memory.

    Args:
        persist_directory: Optional directory to persist the ChromaDB data.
            If not provided, an in-memory client will be used.

    Returns:
        chromadb.Client: A ChromaDB client instance

    """
    if not persist_directory:
        # No directory given: fall back to an ephemeral in-memory client.
        return chromadb.Client(Settings(anonymized_telemetry=False, is_persistent=False))

    # Make sure the target directory exists before handing it to ChromaDB.
    os.makedirs(persist_directory, exist_ok=True)
    telemetry_off = Settings(anonymized_telemetry=False)
    return chromadb.PersistentClient(path=persist_directory, settings=telemetry_off)
32+
33+
34+
class ChromaDBStatePersistence(BaseStatePersistence):
    """Custom implementation of state persistence using ChromaDB.

    States are stored as JSON documents tagged with a `state_id` metadata
    field; lookups and deletions filter on that field.

    NOTE(review): confirm that `save_state`/`load_state`/`delete_state` is the
    contract BaseStatePersistence actually requires in the installed
    pydantic-graph version.
    """

    def __init__(
        self,
        chroma_client: chromadb.Client,
        collection_name: str = "graph_states",
    ):
        """Initialize the ChromaDB state persistence.

        Args:
            chroma_client: ChromaDB client instance
            collection_name: Name of the collection to store states

        """
        self.client = chroma_client
        self.collection = self.client.get_or_create_collection(collection_name)

    async def save_state(self, state_id: str, state: dict[str, Any], metadata: Optional[dict[str, Any]] = None) -> None:
        """Save a state to ChromaDB.

        Args:
            state_id: Unique identifier for the state
            state: The state to save (must be JSON-serializable)
            metadata: Optional metadata to store with the state

        """
        # Serialize the state for storage as a document.
        state_json = json.dumps(state)

        # Copy the caller's metadata (don't mutate their dict) and tag the
        # record with its state_id so it can be looked up later.
        meta = dict(metadata) if metadata else {}
        meta["state_id"] = state_id

        # Document IDs must be unique; suffix with a UUID so repeated saves of
        # the same state_id create new snapshots instead of colliding.
        doc_id = f"{state_id}_{uuid.uuid4()}"

        # Add the state to ChromaDB
        self.collection.add(documents=[state_json], metadatas=[meta], ids=[doc_id])

    async def load_state(self, state_id: str) -> Optional[dict[str, Any]]:
        """Load a state from ChromaDB.

        Args:
            state_id: Unique identifier for the state

        Returns:
            The state if found, None otherwise

        """
        # Use `get` (metadata filter) rather than `query`: `query` performs a
        # similarity search and requires query_texts/query_embeddings, so
        # calling it with only `where` raises. `get` also returns flat lists,
        # so documents[0] is the stored JSON string itself (with `query` it
        # would be a nested list).
        results = self.collection.get(where={"state_id": state_id}, limit=1)

        documents = results.get("documents") or []
        if documents:
            return json.loads(documents[0])
        return None

    async def delete_state(self, state_id: str) -> None:
        """Delete all stored snapshots for a state from ChromaDB.

        Args:
            state_id: Unique identifier for the state

        """
        # Delete directly by metadata filter: removes every snapshot for this
        # state_id in one call, with no arbitrary limit=100 page size and no
        # mis-shaped `query` result indexing.
        self.collection.delete(where={"state_id": state_id})
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from fastapi import APIRouter, HTTPException
2+
from pydantic import BaseModel
3+
4+
from app.v1.routers.agent_workflow.graph import AgentWorkflowGraph
5+
6+
7+
router = APIRouter()
# Single module-level workflow instance shared by all requests to this router.
agent_workflow = AgentWorkflowGraph()
9+
10+
11+
class QueryRequest(BaseModel):
    """Request model for agent workflow queries."""

    # The natural-language query to route through the multi-agent workflow.
    query: str
15+
16+
17+
class QueryResponse(BaseModel):
    """Response model for agent workflow results."""

    # Result payload produced by the selected agent.
    result: dict
    # Name of the agent that handled the query.
    agent_used: str
22+
23+
24+
@router.post("/query", response_model=QueryResponse)
async def process_query(request: QueryRequest) -> QueryResponse:
    """Process a user query through the multi-agent workflow.

    The workflow will automatically select the appropriate agent based on the query.

    Args:
        request: Body containing the user's query string.

    Returns:
        QueryResponse: The agent's result payload and the name of the agent used.

    Raises:
        HTTPException: 500 when the workflow fails unexpectedly.

    """
    try:
        return await agent_workflow.process_query(request.query)
    except HTTPException:
        # Don't wrap deliberate HTTP errors from lower layers in a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing query: {e}") from e

app/v1/routers/base_router.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from fastapi import APIRouter
22

3+
from app.v1.routers.agent_workflow import router as agent_workflow_router
34
from app.v1.routers.users import users_router
45

56

67
router = APIRouter(prefix="/v1")
78
router.include_router(users_router.router, prefix="/users", tags=["Users"])
9+
router.include_router(agent_workflow_router.router, prefix="/agent-workflow", tags=["Agent Workflow"])

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ requires-python = ">=3.9"
99
dependencies = [
1010
"agno>=1.2.6",
1111
"ai-data-science-team>=0.0.0.9016",
12+
"chromadb>=1.0.0",
1213
"commitizen>=4.4.1",
1314
"duckdb>=1.2.1",
1415
"fastapi[standard]>=0.115.8",
@@ -27,6 +28,7 @@ dependencies = [
2728
"phidata>=2.7.10",
2829
"pre-commit>=4.1.0",
2930
"pydantic-ai[logfire]>=0.0.46",
31+
"pydantic-graph>=0.0.52",
3032
"scikit-learn>=1.6.1",
3133
"streamlit>=1.44.1",
3234
]

0 commit comments

Comments
 (0)