Skip to content

Commit eba5788

Browse files
committed
add deep research tab
1 parent 09e3f21 commit eba5788

File tree

6 files changed

+512
-95
lines changed

6 files changed

+512
-95
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,4 @@ data/
187187

188188
# For Config Files (Current Settings)
189189
.config.pkl
190+
*.pdf

src/agent/deep_research/deep_research_agent.py

Lines changed: 39 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,11 @@
3535
logger = logging.getLogger(__name__)
3636

3737
# Constants
38-
TMP_DIR = Path("./tmp/deep_research")
39-
os.makedirs(TMP_DIR, exist_ok=True)
4038
REPORT_FILENAME = "report.md"
4139
PLAN_FILENAME = "research_plan.md"
4240
SEARCH_INFO_FILENAME = "search_info.json"
43-
MAX_PARALLEL_BROWSERS = 1
4441

4542
_AGENT_STOP_FLAGS = {}
46-
_BROWSER_AGENT_INSTANCES = {} # To store running browser agents for stopping
4743

4844

4945
async def run_single_browser_task(
@@ -119,6 +115,7 @@ async def run_single_browser_task(
119115
2. The title of the source page or document.
120116
3. The URL of the source.
121117
Focus on accuracy and relevance. Avoid irrelevant details.
118+
PDF cannot directly extract _content, please try to download first, then using read_file, if you can't save or read, please try other methods.
122119
"""
123120

124121
bu_agent_instance = BrowserUseAgent(
@@ -131,8 +128,7 @@ async def run_single_browser_task(
131128
)
132129

133130
# Store instance for potential stop() call
134-
task_key = f"{task_id}_{uuid.uuid4()}" # Unique key for this run
135-
_BROWSER_AGENT_INSTANCES[task_key] = bu_agent_instance
131+
task_key = f"{task_id}_{uuid.uuid4()}"
136132

137133
# --- Run with Stop Check ---
138134
# BrowserUseAgent needs to internally check a stop signal or have a stop method.
@@ -162,45 +158,46 @@ async def run_single_browser_task(
162158
logger.error(f"Error during browser task for query '{task_query}': {e}", exc_info=True)
163159
return {"query": task_query, "error": str(e), "status": "failed"}
164160
finally:
165-
if task_key in _BROWSER_AGENT_INSTANCES:
166-
del _BROWSER_AGENT_INSTANCES[task_key]
167161
if bu_browser_context:
168162
try:
169163
await bu_browser_context.close()
164+
bu_browser_context = None
170165
logger.info("Closed browser context.")
171166
except Exception as e:
172167
logger.error(f"Error closing browser context: {e}")
173168
if bu_browser:
174169
try:
175170
await bu_browser.close()
171+
bu_browser = None
176172
logger.info("Closed browser.")
177173
except Exception as e:
178174
logger.error(f"Error closing browser: {e}")
179175

180176

181177
class BrowserSearchInput(BaseModel):
182178
queries: List[str] = Field(
183-
description=f"List of distinct search queries (max {MAX_PARALLEL_BROWSERS}) to find information relevant to the research task.")
179+
description=f"List of distinct search queries to find information relevant to the research task.")
184180

185181

186182
async def _run_browser_search_tool(
187183
queries: List[str],
188184
task_id: str, # Injected dependency
189185
llm: Any, # Injected dependency
190-
browser_config: Dict[str, Any], # Injected dependency
191-
stop_event: threading.Event # Injected dependency
186+
browser_config: Dict[str, Any],
187+
stop_event: threading.Event,
188+
max_parallel_browsers: int = 1
192189
) -> List[Dict[str, Any]]:
193190
"""
194191
Internal function to execute parallel browser searches based on LLM-provided queries.
195192
Handles concurrency and stop signals.
196193
"""
197194

198195
# Limit queries just in case LLM ignores the description
199-
queries = queries[:MAX_PARALLEL_BROWSERS]
196+
queries = queries[:max_parallel_browsers]
200197
logger.info(f"[Browser Tool {task_id}] Running search for {len(queries)} queries: {queries}")
201198

202199
results = []
203-
semaphore = asyncio.Semaphore(MAX_PARALLEL_BROWSERS)
200+
semaphore = asyncio.Semaphore(max_parallel_browsers)
204201

205202
async def task_wrapper(query):
206203
async with semaphore:
@@ -240,7 +237,8 @@ def create_browser_search_tool(
240237
llm: Any,
241238
browser_config: Dict[str, Any],
242239
task_id: str,
243-
stop_event: threading.Event
240+
stop_event: threading.Event,
241+
max_parallel_browsers: int = 1,
244242
) -> StructuredTool:
245243
"""Factory function to create the browser search tool with necessary dependencies."""
246244
# Use partial to bind the dependencies that aren't part of the LLM call arguments
@@ -251,15 +249,15 @@ def create_browser_search_tool(
251249
llm=llm,
252250
browser_config=browser_config,
253251
stop_event=stop_event,
252+
max_parallel_browsers=max_parallel_browsers
254253
)
255254

256255
return StructuredTool.from_function(
257256
coroutine=bound_tool_func,
258257
name="parallel_browser_search",
259258
description=f"""Use this tool to actively search the web for information related to a specific research task or question.
260-
It runs up to {MAX_PARALLEL_BROWSERS} searches in parallel using a browser agent for better results than simple scraping.
261-
Provide a list of distinct search queries that are likely to yield relevant information.
262-
The tool returns a list of results, each containing the original query, the status (completed, failed, stopped), and the summarized information found (or an error message).""",
259+
It runs up to {max_parallel_browsers} searches in parallel using a browser agent for better results than simple scraping.
260+
Provide a list of distinct search queries that are likely to yield relevant information.""",
263261
args_schema=BrowserSearchInput,
264262
)
265263

@@ -747,7 +745,7 @@ def should_continue(state: DeepResearchState) -> str:
747745
return "end_run" # Should not happen if planning node ran correctly
748746

749747
# Check if there are pending steps in the plan
750-
if current_index < len(plan):
748+
if current_index < 2:
751749
logger.info(
752750
f"Plan has pending steps (current index {current_index}/{len(plan)}). Routing to Research Execution.")
753751
return "execute_research"
@@ -758,7 +756,7 @@ def should_continue(state: DeepResearchState) -> str:
758756

759757
# --- DeepSearchAgent Class ---
760758

761-
class DeepSearchAgent:
759+
class DeepResearchAgent:
762760
def __init__(self, llm: Any, browser_config: Dict[str, Any], mcp_server_config: Optional[Dict[str, Any]] = None):
763761
"""
764762
Initializes the DeepSearchAgent.
@@ -773,37 +771,44 @@ def __init__(self, llm: Any, browser_config: Dict[str, Any], mcp_server_config:
773771
self.browser_config = browser_config
774772
self.mcp_server_config = mcp_server_config
775773
self.mcp_client = None
774+
self.stopped = False
776775
self.graph = self._compile_graph()
777776
self.current_task_id: Optional[str] = None
778777
self.stop_event: Optional[threading.Event] = None
779778
self.runner: Optional[asyncio.Task] = None # To hold the asyncio task for run
780779

781-
async def _setup_tools(self, task_id: str, stop_event: threading.Event) -> List[Tool]:
780+
async def _setup_tools(self, task_id: str, stop_event: threading.Event, max_parallel_browsers: int = 1) -> List[
781+
Tool]:
782782
"""Sets up the basic tools (File I/O) and optional MCP tools."""
783783
tools = [WriteFileTool(), ReadFileTool(), ListDirectoryTool()] # Basic file operations
784784
browser_use_tool = create_browser_search_tool(
785785
llm=self.llm,
786786
browser_config=self.browser_config,
787787
task_id=task_id,
788-
stop_event=stop_event
788+
stop_event=stop_event,
789+
max_parallel_browsers=max_parallel_browsers
789790
)
790791
tools += [browser_use_tool]
791792
# Add MCP tools if config is provided
792793
if self.mcp_server_config:
793794
try:
794795
logger.info("Setting up MCP client and tools...")
795-
if self.mcp_client:
796-
await self.mcp_client.__aexit__(None, None, None)
797-
self.mcp_client = await setup_mcp_client_and_tools(self.mcp_server_config)
796+
if not self.mcp_client:
797+
self.mcp_client = await setup_mcp_client_and_tools(self.mcp_server_config)
798798
mcp_tools = self.mcp_client.get_tools()
799799
logger.info(f"Loaded {len(mcp_tools)} MCP tools.")
800800
tools.extend(mcp_tools)
801801
except Exception as e:
802802
logger.error(f"Failed to set up MCP tools: {e}", exc_info=True)
803803
elif self.mcp_server_config:
804804
logger.warning("MCP server config provided, but setup function unavailable.")
805+
tools_map = {tool.name: tool for tool in tools}
806+
return tools_map.values()
805807

806-
return tools
808+
async def close_mcp_client(self):
809+
if self.mcp_client:
810+
await self.mcp_client.__aexit__(None, None, None)
811+
self.mcp_client = None
807812

808813
def _compile_graph(self) -> StateGraph:
809814
"""Compiles the Langgraph state machine."""
@@ -836,7 +841,9 @@ def _compile_graph(self) -> StateGraph:
836841
app = workflow.compile()
837842
return app
838843

839-
async def run(self, topic: str, task_id: Optional[str] = None) -> Dict[str, Any]:
844+
async def run(self, topic: str, task_id: Optional[str] = None, save_dir: str = "./tmp/deep_research",
845+
max_parallel_browsers: int = 1) -> Dict[
846+
str, Any]:
840847
"""
841848
Starts the deep research process (Async Generator Version).
842849
@@ -853,15 +860,15 @@ async def run(self, topic: str, task_id: Optional[str] = None) -> Dict[str, Any]
853860
return {"status": "error", "message": "Agent already running.", "task_id": self.current_task_id}
854861

855862
self.current_task_id = task_id if task_id else str(uuid.uuid4())
856-
output_dir = os.path.join(TMP_DIR, self.current_task_id)
863+
output_dir = os.path.join(save_dir, self.current_task_id)
857864
os.makedirs(output_dir, exist_ok=True)
858865

859866
logger.info(f"[AsyncGen] Starting research task ID: {self.current_task_id} for topic: '{topic}'")
860867
logger.info(f"[AsyncGen] Output directory: {output_dir}")
861868

862869
self.stop_event = threading.Event()
863870
_AGENT_STOP_FLAGS[self.current_task_id] = self.stop_event
864-
agent_tools = await self._setup_tools(self.current_task_id, self.stop_event)
871+
agent_tools = await self._setup_tools(self.current_task_id, self.stop_event, max_parallel_browsers)
865872
initial_state: DeepResearchState = {
866873
"task_id": self.current_task_id,
867874
"topic": topic,
@@ -933,19 +940,7 @@ async def run(self, topic: str, task_id: Optional[str] = None) -> Dict[str, Any]
933940
# final_state will remain None or the state before the error
934941
finally:
935942
logger.info(f"Cleaning up resources for task {self.current_task_id}")
936-
task_id_to_clean = self.current_task_id # Store before potentially clearing
937-
if task_id_to_clean in _AGENT_STOP_FLAGS:
938-
del _AGENT_STOP_FLAGS[task_id_to_clean]
939-
# Stop any potentially lingering browser agents for this task
940-
await self._stop_lingering_browsers(task_id_to_clean)
941-
# Ensure the instance tracker is clean (should be handled by tool's finally block)
942-
lingering_keys = [k for k in _BROWSER_AGENT_INSTANCES if k.startswith(f"{task_id_to_clean}_")]
943-
if lingering_keys:
944-
logger.warning(
945-
f"{len(lingering_keys)} lingering browser instances found in tracker for task {task_id_to_clean} after cleanup attempt.")
946-
# Force clear them from the tracker dict
947-
for key in lingering_keys:
948-
del _BROWSER_AGENT_INSTANCES[key]
943+
task_id_to_clean = self.current_task_id
949944

950945
self.stop_event = None
951946
self.current_task_id = None
@@ -961,28 +956,6 @@ async def run(self, topic: str, task_id: Optional[str] = None) -> Dict[str, Any]
961956
"final_state": final_state if final_state else {} # Return the final state dict
962957
}
963958

964-
async def _stop_lingering_browsers(self, task_id):
965-
"""Attempts to stop any BrowserUseAgent instances associated with the task_id."""
966-
keys_to_stop = [key for key in _BROWSER_AGENT_INSTANCES if key.startswith(f"{task_id}_")]
967-
if not keys_to_stop:
968-
return
969-
970-
logger.warning(
971-
f"Found {len(keys_to_stop)} potentially lingering browser agents for task {task_id}. Attempting stop...")
972-
for key in keys_to_stop:
973-
agent_instance = _BROWSER_AGENT_INSTANCES.get(key)
974-
if agent_instance and hasattr(agent_instance, 'stop'):
975-
try:
976-
# Assuming BU agent has an async stop method
977-
await agent_instance.stop()
978-
logger.info(f"Called stop() on browser agent instance {key}")
979-
except Exception as e:
980-
logger.error(f"Error calling stop() on browser agent instance {key}: {e}")
981-
# Instance should be removed by the finally block in run_single_browser_task
982-
# but we ensure removal here too.
983-
if key in _BROWSER_AGENT_INSTANCES:
984-
del _BROWSER_AGENT_INSTANCES[key]
985-
986959
def stop(self):
987960
"""Signals the currently running agent task to stop."""
988961
if not self.current_task_id or not self.stop_event:
@@ -991,14 +964,7 @@ def stop(self):
991964

992965
logger.info(f"Stop requested for task ID: {self.current_task_id}")
993966
self.stop_event.set() # Signal the stop event
967+
self.stopped = True
994968

995-
# Additionally, try to stop the browser agents directly
996-
# Need to run this async in the background or manage event loops carefully
997-
async def do_stop_browsers():
998-
await self._stop_lingering_browsers(self.current_task_id)
999-
1000-
try:
1001-
loop = asyncio.get_running_loop()
1002-
loop.create_task(do_stop_browsers())
1003-
except RuntimeError: # No running loop in current thread
1004-
asyncio.run(do_stop_browsers())
969+
def close(self):
970+
self.stopped = False

src/utils/llm_provider.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from langchain_ollama import ChatOllama
4747
from langchain_openai import AzureChatOpenAI, ChatOpenAI
4848
from langchain_ibm import ChatWatsonx
49+
from langchain_aws import ChatBedrock
50+
from pydantic import SecretStr
4951

5052
from src.utils import config
5153

@@ -154,7 +156,7 @@ def get_llm_model(provider: str, **kwargs):
154156
:param kwargs:
155157
:return:
156158
"""
157-
if provider not in ["ollama"]:
159+
if provider not in ["ollama", "bedrock"]:
158160
env_var = f"{provider.upper()}_API_KEY"
159161
api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
160162
if not api_key:
@@ -263,6 +265,23 @@ def get_llm_model(provider: str, **kwargs):
263265
azure_endpoint=base_url,
264266
api_key=api_key,
265267
)
268+
elif provider == "bedrock":
269+
if not kwargs.get("base_url", ""):
270+
access_key_id = os.getenv('AWS_ACCESS_KEY_ID', '')
271+
else:
272+
access_key_id = kwargs.get("base_url")
273+
274+
if not kwargs.get("api_key", ""):
275+
api_key = os.getenv('AWS_SECRET_ACCESS_KEY', '')
276+
else:
277+
api_key = kwargs.get("api_key")
278+
return ChatBedrock(
279+
model=kwargs.get("model_name", 'anthropic.claude-3-5-sonnet-20241022-v2:0'),
280+
region=kwargs.get("bedrock_region", 'us-west-2'), # with higher quota
281+
aws_access_key_id=SecretStr(access_key_id),
282+
aws_secret_access_key=SecretStr(api_key),
283+
temperature=kwargs.get("temperature", 0.0),
284+
)
266285
elif provider == "alibaba":
267286
if not kwargs.get("base_url", ""):
268287
base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")

0 commit comments

Comments
 (0)