Skip to content

Commit cc65122

Browse files
kausmeowsdirkbrnd
andauthored
feat: add arun in workflows (#3396)
## Summary add `arun` in workflows (If applicable, issue number: #____) ## Type of change - [ ] Bug fix - [x] New feature - [ ] Breaking change - [ ] Improvement - [ ] Model update - [ ] Other: --- ## Checklist - [x] Code complies with style guidelines - [x] Ran format/validation scripts (`./scripts/format.sh` and `./scripts/validate.sh`) - [x] Self-review completed - [x] Documentation updated (comments, docstrings) - [x] Examples and guides: Relevant cookbook examples have been included or updated (if applicable) - [x] Tested in clean environment - [ ] Tests added/updated (if applicable) --- ## Additional Notes Add any important context (deployment instructions, screenshots, security considerations, etc.) --------- Co-authored-by: Dirk Brand <dirkbrnd@gmail.com>
1 parent cd0283b commit cc65122

File tree

2 files changed

+278
-6
lines changed

2 files changed

+278
-6
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Please install dependencies using:
2+
pip install openai newspaper4k lxml_html_clean agno httpx
3+
"""
4+
5+
import asyncio
6+
import json
7+
from typing import AsyncIterator
8+
9+
import httpx
10+
from agno.agent import Agent, RunResponse
11+
from agno.tools.newspaper4k import Newspaper4kTools
12+
from agno.utils.log import logger
13+
from agno.utils.pprint import pprint_run_response
14+
from agno.workflow import RunEvent, Workflow
15+
16+
17+
class AsyncHackerNewsReporter(Workflow):
18+
description: str = (
19+
"Get the top stories from Hacker News and write a report on them."
20+
)
21+
22+
hn_agent: Agent = Agent(
23+
description="Get the top stories from hackernews. "
24+
"Share all possible information, including url, score, title and summary if available.",
25+
show_tool_calls=True,
26+
)
27+
28+
writer: Agent = Agent(
29+
tools=[Newspaper4kTools()],
30+
description="Write an engaging report on the top stories from hackernews.",
31+
instructions=[
32+
"You will be provided with top stories and their links.",
33+
"Carefully read each article and think about the contents",
34+
"Then generate a final New York Times worthy article",
35+
"Break the article into sections and provide key takeaways at the end.",
36+
"Make sure the title is catchy and engaging.",
37+
"Share score, title, url and summary of every article.",
38+
"Give the section relevant titles and provide details/facts/processes in each section."
39+
"Ignore articles that you cannot read or understand.",
40+
"REMEMBER: you are writing for the New York Times, so the quality of the article is important.",
41+
],
42+
)
43+
44+
async def get_top_hackernews_stories(self, num_stories: int = 10) -> str:
45+
"""Use this function to get top stories from Hacker News.
46+
47+
Args:
48+
num_stories (int): Number of stories to return. Defaults to 10.
49+
50+
Returns:
51+
str: JSON string of top stories.
52+
"""
53+
async with httpx.AsyncClient() as client:
54+
# Fetch top story IDs
55+
response = await client.get(
56+
"https://hacker-news.firebaseio.com/v0/topstories.json"
57+
)
58+
story_ids = response.json()
59+
60+
# Fetch story details concurrently
61+
tasks = [
62+
client.get(
63+
f"https://hacker-news.firebaseio.com/v0/item/{story_id}.json"
64+
)
65+
for story_id in story_ids[:num_stories]
66+
]
67+
responses = await asyncio.gather(*tasks)
68+
69+
stories = []
70+
for response in responses:
71+
story = response.json()
72+
story["username"] = story["by"]
73+
stories.append(story)
74+
75+
return json.dumps(stories)
76+
77+
async def arun(self, num_stories: int = 5) -> AsyncIterator[RunResponse]:
78+
# Set the tools for hn_agent here to avoid circular reference
79+
self.hn_agent.tools = [self.get_top_hackernews_stories]
80+
81+
logger.info(f"Getting top {num_stories} stories from HackerNews.")
82+
top_stories: RunResponse = await self.hn_agent.arun(num_stories=num_stories)
83+
if top_stories is None or not top_stories.content:
84+
yield RunResponse(
85+
run_id=self.run_id,
86+
content="Sorry, could not get the top stories.",
87+
event=RunEvent.workflow_completed,
88+
)
89+
return
90+
91+
logger.info("Reading each story and writing a report.")
92+
# Get the async iterator from writer.arun()
93+
writer_response = await self.writer.arun(top_stories.content, stream=True)
94+
95+
# Stream the writer's response directly
96+
async for response in writer_response:
97+
if response.content:
98+
yield RunResponse(
99+
content=response.content, event=response.event, run_id=self.run_id
100+
)
101+
102+
103+
if __name__ == "__main__":
104+
import asyncio
105+
106+
async def main():
107+
# Initialize the workflow
108+
workflow = AsyncHackerNewsReporter(debug_mode=False)
109+
110+
# Run the workflow and collect the final response
111+
final_content = []
112+
try:
113+
async for response in workflow.arun(num_stories=5):
114+
if response.content:
115+
final_content.append(response.content)
116+
except Exception as e:
117+
logger.error(f"Error running workflow: {e}")
118+
return
119+
120+
# Create final response with combined content
121+
if final_content:
122+
final_response = RunResponse(
123+
content="".join(final_content), event=RunEvent.workflow_completed
124+
)
125+
# Pretty print the final response
126+
pprint_run_response(final_response, markdown=True, show_time=True)
127+
128+
# Run the async main function
129+
asyncio.run(main())

libs/agno/agno/workflow/workflow.py

Lines changed: 149 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field, fields
66
from os import getenv
77
from types import GeneratorType
8-
from typing import Any, Callable, Dict, List, Optional, Union, cast
8+
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, Union, cast
99
from uuid import uuid4
1010

1111
from pydantic import BaseModel
@@ -248,6 +248,118 @@ def result_generator():
248248
logger.warning(f"Workflow.run() should only return RunResponse objects, got: {type(result)}")
249249
return None
250250

251+
# Add to workflow.py after the run_workflow method
252+
253+
async def arun_workflow(self, **kwargs: Any):
254+
"""Run the Workflow asynchronously"""
255+
256+
# Set mode, debug, workflow_id, session_id, initialize memory
257+
self.set_storage_mode()
258+
self.set_debug()
259+
self.set_monitoring()
260+
self.set_workflow_id() # Ensure workflow_id is set
261+
self.set_session_id()
262+
self.initialize_memory()
263+
264+
# Update workflow_id for all agents before registration
265+
for field_name, value in self.__class__.__dict__.items():
266+
if isinstance(value, Agent):
267+
value.initialize_agent()
268+
value.workflow_id = self.workflow_id
269+
270+
if isinstance(value, Team):
271+
value.initialize_team()
272+
value.workflow_id = self.workflow_id
273+
274+
# Register the workflow, which will also register agents and teams
275+
await self.aregister_workflow()
276+
277+
# Create a run_id
278+
self.run_id = str(uuid4())
279+
280+
# Set run_input, run_response
281+
self.run_input = kwargs
282+
self.run_response = RunResponse(run_id=self.run_id, session_id=self.session_id, workflow_id=self.workflow_id)
283+
284+
# Read existing session from storage
285+
self.read_from_storage()
286+
287+
# Update the session_id for all Agent instances
288+
self.update_agent_session_ids()
289+
290+
log_debug(f"Workflow Run Start: {self.run_id}", center=True)
291+
try:
292+
self._subclass_run = cast(Callable, self._subclass_run)
293+
result = await self._subclass_run(**kwargs)
294+
except Exception as e:
295+
logger.error(f"Workflow.arun() failed: {e}")
296+
raise e
297+
298+
# Handle async iterator results
299+
if isinstance(result, (AsyncIterator, AsyncGenerator)):
300+
# Initialize the run_response content
301+
self.run_response.content = ""
302+
303+
async def result_generator():
304+
self.run_response = cast(RunResponse, self.run_response)
305+
if isinstance(self.memory, WorkflowMemory):
306+
self.memory = cast(WorkflowMemory, self.memory)
307+
elif isinstance(self.memory, Memory):
308+
self.memory = cast(Memory, self.memory)
309+
310+
async for item in result:
311+
if isinstance(item, RunResponse):
312+
# Update the run_id, session_id and workflow_id of the RunResponse
313+
item.run_id = self.run_id
314+
item.session_id = self.session_id
315+
item.workflow_id = self.workflow_id
316+
317+
# Update the run_response with the content from the result
318+
if item.content is not None and isinstance(item.content, str):
319+
self.run_response.content += item.content
320+
else:
321+
logger.warning(f"Workflow.arun() should only yield RunResponse objects, got: {type(item)}")
322+
yield item
323+
324+
# Add the run to the memory
325+
if isinstance(self.memory, WorkflowMemory):
326+
self.memory.add_run(WorkflowRun(input=self.run_input, response=self.run_response))
327+
elif isinstance(self.memory, Memory):
328+
self.memory.add_run(session_id=self.session_id, run=self.run_response) # type: ignore
329+
# Write this run to the database
330+
self.write_to_storage()
331+
log_debug(f"Workflow Run End: {self.run_id}", center=True)
332+
333+
return result_generator()
334+
# Handle single RunResponse result
335+
elif isinstance(result, RunResponse):
336+
# Update the result with the run_id, session_id and workflow_id of the workflow run
337+
result.run_id = self.run_id
338+
result.session_id = self.session_id
339+
result.workflow_id = self.workflow_id
340+
341+
# Update the run_response with the content from the result
342+
if result.content is not None and isinstance(result.content, str):
343+
self.run_response.content = result.content
344+
345+
# Add the run to the memory
346+
if isinstance(self.memory, WorkflowMemory):
347+
self.memory.add_run(WorkflowRun(input=self.run_input, response=self.run_response))
348+
elif isinstance(self.memory, Memory):
349+
self.memory.add_run(session_id=self.session_id, run=self.run_response) # type: ignore
350+
# Write this run to the database
351+
self.write_to_storage()
352+
log_debug(f"Workflow Run End: {self.run_id}", center=True)
353+
return result
354+
else:
355+
logger.warning(f"Workflow.arun() should only return RunResponse objects, got: {type(result)}")
356+
return None
357+
358+
async def arun(self, **kwargs: Any):
359+
"""Async version of run() that calls arun_workflow()"""
360+
logger.error(f"{self.__class__.__name__}.arun() method not implemented.")
361+
return
362+
251363
def set_storage_mode(self):
252364
if self.storage is not None:
253365
self.storage.mode = "workflow"
@@ -294,11 +406,17 @@ def update_run_method(self):
294406
# First, check if the subclass has a run method
295407
# If the run() method has been overridden by the subclass,
296408
# then self.__class__.run is not Workflow.run will be True
297-
if self.__class__.run is not Workflow.run:
298-
# Store the original run method bound to the instance in self._subclass_run
299-
self._subclass_run = self.__class__.run.__get__(self)
300-
# Get the parameters of the run method
301-
sig = inspect.signature(self.__class__.run)
409+
if self.__class__.run is not Workflow.run or self.__class__.arun is not Workflow.arun:
410+
# Store the original run methods bound to the instance
411+
if self.__class__.run is not Workflow.run:
412+
self._subclass_run = self.__class__.run.__get__(self)
413+
# Get the parameters of the sync run method
414+
sig = inspect.signature(self.__class__.run)
415+
if self.__class__.arun is not Workflow.arun:
416+
self._subclass_run = self.__class__.arun.__get__(self)
417+
# Get the parameters of the async run method
418+
sig = inspect.signature(self.__class__.arun)
419+
302420
# Convert parameters to a serializable format
303421
self._run_parameters = {
304422
param_name: {
@@ -662,6 +780,31 @@ def _deep_copy_field(self, field_name: str, field_value: Any) -> Any:
662780
# For other types, return as is
663781
return field_value
664782

783+
async def aregister_workflow(self, force: bool = False) -> None:
784+
"""Async version of register_workflow"""
785+
self.set_monitoring()
786+
if not self.monitoring:
787+
return
788+
789+
if not self.workflow_id:
790+
self.set_workflow_id()
791+
792+
try:
793+
from agno.api.schemas.workflows import WorkflowCreate
794+
from agno.api.workflows import acreate_workflow
795+
796+
workflow_config = self.to_config_dict()
797+
# Register the workflow as an app
798+
await acreate_workflow(
799+
workflow=WorkflowCreate(
800+
name=self.name, workflow_id=self.workflow_id, app_id=self.app_id, config=workflow_config
801+
)
802+
)
803+
804+
log_debug(f"Registered workflow: {self.name} (ID: {self.workflow_id})")
805+
except Exception as e:
806+
log_warning(f"Failed to register workflow: {e}")
807+
665808
def register_workflow(self, force: bool = False) -> None:
666809
"""Register this workflow with Agno's platform.
667810

0 commit comments

Comments
 (0)