Skip to content

Commit 51ad84e

Browse files
committed
Refactor codebase: remove v1 API, add visualization server and agent orchestrator
1 parent 1ecc2a4 commit 51ad84e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+10508
-3256
lines changed

.coveragerc

Lines changed: 0 additions & 13 deletions
This file was deleted.

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
33
rev: v4.6.0
44
hooks:
5-
- id: check-added-large-files
5+
# - id: check-added-large-files
66
- id: check-ast
77
- id: check-builtin-literals
88
- id: check-case-conflict

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.12

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ COPY ./app /code/app
1414
RUN groupadd -r nonroot && useradd -r -g nonroot nonroot
1515
USER nonroot
1616

17-
CMD ["/code/.venv/bin/uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
17+
CMD ["/code/.venv/bin/uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
1818
# To build the Docker image, use the following command:
1919
# docker build -t my-fastapi-app .
2020

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pytest:
2-
pytest --cov-report term --cov=app --cov-config=.coveragerc ./tests -v --ignore=tests/tools/
2+
pytest --cov-report term --cov=app ./tests
33

44
pre-commit:
55
pre-commit run --all-files
@@ -19,7 +19,7 @@ ruff-format:
1919
ruff: ruff-check ruff-format
2020

2121
run:
22-
uv run uvicorn app.main:app --port 8000 --reload
22+
uv run uvicorn main:app --port 8000 --reload
2323

2424
ops: pytest pre-commit clean ruff
2525
@echo "\033[92mAll operations completed successfully.\033[0m"

app/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# app/__init__.py
2+
3+
from .agent_orchestrator import determine_data_source, process_user_input_stream, run_agent_orchestrator
4+
from .tools.data_analyst_agent import DataVisualizationAgent
5+
from .tools.sql_data_analyst_agent import SQLDataAnalysisAgent
6+
from .visualization_server import VisualizationServer
7+
8+
9+
__all__ = [
10+
"determine_data_source",
11+
"process_user_input_stream",
12+
"run_agent_orchestrator",
13+
"VisualizationServer",
14+
"DataVisualizationAgent",
15+
"SQLDataAnalysisAgent",
16+
]

app/agent_orchestrator.py

Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
import os
2+
3+
from dataclasses import dataclass
4+
from typing import Any, TypedDict
5+
6+
import pandas as pd
7+
import sqlalchemy as sql
8+
9+
from langchain_core.language_models import BaseChatModel
10+
from langchain_openai import ChatOpenAI
11+
from pydantic_ai import Agent, RunContext
12+
from pydantic_ai.usage import UsageLimits
13+
14+
from app.tools.data_analyst_agent import DataVisualizationAgent
15+
from app.tools.sql_data_analyst_agent import SQLDataAnalysisAgent
16+
17+
18+
# Define our dependency type for orchestration
19+
@dataclass
20+
class OrchestratorDependency:
21+
"""Dependencies for the orchestrator agent."""
22+
23+
user_prompt: str
24+
model: BaseChatModel
25+
data: pd.DataFrame | None = None
26+
db_connection: sql.engine.base.Connection | None = None
27+
usage_limits: UsageLimits | None = None
28+
29+
30+
# Type for streamed results
31+
class AnalysisResult(TypedDict):
32+
success: bool
33+
message: str
34+
visualization_path: str | None
35+
error: str | None
36+
data_summary: dict[str, Any] | None
37+
38+
39+
# Create our master orchestrator agent
40+
orchestrator_agent = Agent(
41+
"openai:gpt-4o",
42+
deps_type=OrchestratorDependency,
43+
result_type=AnalysisResult,
44+
system_prompt="""
45+
You are an expert data analysis orchestrator. Your job is to:
46+
1. Understand user requests related to data analysis and visualization
47+
2. Determine whether to use SQL database analysis or direct DataFrame analysis
48+
3. Call the appropriate agent to handle the request
49+
4. Return results in a clear, organized manner
50+
51+
For SQL database requests, use the sql_agent tool.
52+
For DataFrame visualization requests, use the visualization_agent tool.
53+
""",
54+
)
55+
56+
57+
@orchestrator_agent.tool
58+
async def sql_agent(ctx: RunContext[OrchestratorDependency], query: str) -> dict[str, Any]: # noqa: D417
59+
"""Process a SQL database query and visualization request.
60+
61+
Args:
62+
query: The user's analysis request/question about the database
63+
64+
Returns:
65+
A dictionary with the analysis results
66+
67+
"""
68+
if ctx.deps.db_connection is None:
69+
return {"error": "Database connection is required but not provided"}
70+
71+
# Initialize the SQL agent with the provided connection
72+
sql_agent = SQLDataAnalysisAgent(
73+
model=ctx.deps.model, connection=ctx.deps.db_connection, n_samples=5, log=True, log_path="logs/", verbose=True
74+
)
75+
76+
# Execute the query but don't auto-display (we'll handle that)
77+
results = sql_agent.invoke_agent(query, auto_display=False)
78+
79+
# Check for errors
80+
if results.get("error"):
81+
return {"success": False, "error": results.get("error"), "message": f"Analysis failed: {results.get('error')}"}
82+
83+
# Get visualization path if available
84+
vis_path = None
85+
if results.get("plotly_graph"):
86+
if not os.path.exists("visualizations"):
87+
os.makedirs("visualizations")
88+
vis_path = "visualizations/analysis_result.html"
89+
results.get("plotly_graph").write_html(vis_path)
90+
91+
# Get data summary
92+
data_summary = None
93+
df = sql_agent.get_data_sql()
94+
if df is not None and not isinstance(df, str):
95+
data_summary = {
96+
"shape": df.shape,
97+
"columns": list(df.columns),
98+
"sample": df.head(5).to_dict() if len(df) > 0 else {},
99+
}
100+
101+
return {
102+
"success": True,
103+
"message": "SQL analysis completed successfully",
104+
"visualization_path": vis_path,
105+
"sql_query": sql_agent.get_sql_query_code(),
106+
"data_summary": data_summary,
107+
}
108+
109+
110+
@orchestrator_agent.tool
111+
async def visualization_agent(ctx: RunContext[OrchestratorDependency], instructions: str) -> dict[str, Any]: # noqa: D417
112+
"""Create a visualization from a DataFrame based on instructions.
113+
114+
Args:
115+
instructions: The visualization instructions
116+
117+
Returns:
118+
A dictionary with the visualization results
119+
120+
"""
121+
if ctx.deps.data is None:
122+
return {"error": "DataFrame is required but not provided"}
123+
124+
# Initialize the visualization agent
125+
vis_agent = DataVisualizationAgent(model=ctx.deps.model, log=True, log_path="logs/")
126+
127+
# Generate the visualization
128+
response = vis_agent.generate_visualization(data=ctx.deps.data, instructions=instructions)
129+
130+
# Check for errors
131+
if not response.get("success", False):
132+
return {
133+
"success": False,
134+
"error": response.get("error", "Unknown error"),
135+
"message": f"Visualization failed: {response.get('error', 'Unknown error')}",
136+
}
137+
138+
# Save the visualization if available
139+
vis_path = None
140+
fig = vis_agent.get_plotly_figure()
141+
if fig:
142+
if not os.path.exists("visualizations"):
143+
os.makedirs("visualizations")
144+
vis_path = "visualizations/analysis_result.html"
145+
fig.write_html(vis_path)
146+
147+
return {
148+
"success": True,
149+
"message": "Visualization created successfully",
150+
"visualization_path": vis_path,
151+
"visualization_code": vis_agent.get_visualization_code(),
152+
"explanation": response.get("explanation", ""),
153+
}
154+
155+
156+
@orchestrator_agent.tool
157+
async def determine_data_source(ctx: RunContext[OrchestratorDependency], query: str) -> str: # noqa: D417
158+
"""Determine whether to use SQL database or DataFrame analysis based on the query.
159+
160+
Args:
161+
query: The user's analysis request/question
162+
163+
Returns:
164+
A recommendation for which data source to use ("sql" or "dataframe")
165+
166+
"""
167+
# Check if we have both options available
168+
has_db = ctx.deps.db_connection is not None
169+
has_df = ctx.deps.data is not None
170+
171+
# If we only have one option, use that
172+
if has_db and not has_df:
173+
return "sql"
174+
if has_df and not has_db:
175+
return "dataframe"
176+
177+
# If we have both options, determine based on query content
178+
sql_keywords = ["sql", "database", "table", "query", "join", "select", "from", "where"]
179+
has_sql_keywords = any(keyword in query.lower() for keyword in sql_keywords)
180+
181+
if has_sql_keywords:
182+
return "sql"
183+
return "dataframe"
184+
185+
186+
async def process_user_input(
187+
user_input: str,
188+
data: pd.DataFrame = None,
189+
db_connection: sql.engine.base.Connection = None,
190+
usage_limits: UsageLimits = None,
191+
) -> dict[str, Any]:
192+
"""Process a user input with the orchestrator agent.
193+
194+
Args:
195+
user_input: The user's prompt/question
196+
data: Optional DataFrame to analyze
197+
db_connection: Optional database connection
198+
usage_limits: Optional usage limits
199+
200+
Returns:
201+
The results of the analysis
202+
203+
"""
204+
# Set up the LLM
205+
model = ChatOpenAI(model_name="gpt-4o")
206+
207+
# Create dependencies
208+
deps = OrchestratorDependency(
209+
user_prompt=user_input, model=model, data=data, db_connection=db_connection, usage_limits=usage_limits
210+
)
211+
212+
# Run the agent
213+
result = await orchestrator_agent.run(user_input, deps=deps, usage_limits=usage_limits)
214+
215+
return result.data
216+
217+
218+
async def run_agent_orchestrator(
219+
user_input: str, data_path: str = None, db_url: str = None, usage_limits: UsageLimits = None
220+
) -> dict[str, Any]:
221+
"""Run the agent orchestrator with file path or database URL.
222+
223+
Args:
224+
user_input: The user's prompt/question
225+
data_path: Optional path to a data file (CSV, Excel)
226+
db_url: Optional database URL
227+
usage_limits: Optional usage limits
228+
229+
Returns:
230+
The results of the analysis
231+
232+
"""
233+
data = None
234+
db_connection = None
235+
236+
# Load data if provided
237+
if data_path:
238+
if data_path.endswith(".csv"):
239+
data = pd.read_csv(data_path)
240+
elif data_path.endswith((".xls", ".xlsx")):
241+
data = pd.read_excel(data_path)
242+
else:
243+
return {"error": "Unsupported file format. Please use .csv, .xls, or .xlsx"}
244+
245+
# Set up database connection if provided
246+
if db_url:
247+
try:
248+
engine = sql.create_engine(db_url)
249+
db_connection = engine.connect()
250+
except Exception as e:
251+
return {"error": f"Failed to connect to database: {str(e)}"}
252+
253+
try:
254+
# Process the request
255+
result = await process_user_input(
256+
user_input=user_input, data=data, db_connection=db_connection, usage_limits=usage_limits
257+
)
258+
259+
# Clean up database connection if we created one
260+
if db_connection:
261+
db_connection.close()
262+
263+
return result
264+
except Exception as e:
265+
if db_connection:
266+
db_connection.close()
267+
return {"error": str(e)}
268+
269+
270+
# Streaming version of the process_user_input function
271+
async def process_user_input_stream(
272+
user_input: str,
273+
data: pd.DataFrame = None,
274+
db_connection: sql.engine.base.Connection = None,
275+
usage_limits: UsageLimits = None,
276+
):
277+
"""Process a user input with the orchestrator agent and stream the results.
278+
279+
Args:
280+
user_input: The user's prompt/question
281+
data: Optional DataFrame to analyze
282+
db_connection: Optional database connection
283+
usage_limits: Optional usage limits
284+
285+
Returns:
286+
An async generator that yields progress updates
287+
288+
"""
289+
# Set up the LLM
290+
model = ChatOpenAI(model_name="gpt-4o")
291+
292+
# Create dependencies
293+
deps = OrchestratorDependency(
294+
user_prompt=user_input, model=model, data=data, db_connection=db_connection, usage_limits=usage_limits
295+
)
296+
297+
# First yield the starting message
298+
yield "Starting analysis...\n"
299+
300+
try:
301+
# Run the agent and get the result (non-streaming first)
302+
run_result = await orchestrator_agent.run(user_input, deps=deps, usage_limits=usage_limits)
303+
304+
# Yield progress updates
305+
yield "Processing data and creating visualization...\n"
306+
307+
# Get the final result
308+
result = run_result.data
309+
310+
# Yield the final result summary
311+
if result.get("success", False):
312+
yield "\nAnalysis completed successfully!\n"
313+
if result.get("visualization_path"):
314+
yield f"Visualization saved to: {result.get('visualization_path')}\n"
315+
yield "You can view the visualization in your browser.\n"
316+
317+
if result.get("data_summary"):
318+
yield "\nData Summary:\n"
319+
shape = result.get("data_summary", {}).get("shape")
320+
if shape:
321+
yield f"- Shape: {shape[0]} rows × {shape[1]} columns\n"
322+
323+
columns = result.get("data_summary", {}).get("columns")
324+
if columns:
325+
yield f"- Columns: {', '.join(columns)}\n"
326+
else:
327+
yield f"\nAnalysis failed: {result.get('error', 'Unknown error')}\n"
328+
329+
except Exception as e:
330+
# Handle any exceptions
331+
yield f"\nError during analysis: {str(e)}\n"

0 commit comments

Comments
 (0)