|
5 | 5 | from dataclasses import dataclass, field, fields |
6 | 6 | from os import getenv |
7 | 7 | from types import GeneratorType |
8 | | -from typing import Any, Callable, Dict, List, Optional, Union, cast |
| 8 | +from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, Union, cast |
9 | 9 | from uuid import uuid4 |
10 | 10 |
|
11 | 11 | from pydantic import BaseModel |
@@ -248,6 +248,118 @@ def result_generator(): |
248 | 248 | logger.warning(f"Workflow.run() should only return RunResponse objects, got: {type(result)}") |
249 | 249 | return None |
250 | 250 |
|
| 251 | + # Add to workflow.py after the run_workflow method |
| 252 | + |
| 253 | + async def arun_workflow(self, **kwargs: Any): |
| 254 | + """Run the Workflow asynchronously""" |
| 255 | + |
| 256 | + # Set mode, debug, workflow_id, session_id, initialize memory |
| 257 | + self.set_storage_mode() |
| 258 | + self.set_debug() |
| 259 | + self.set_monitoring() |
| 260 | + self.set_workflow_id() # Ensure workflow_id is set |
| 261 | + self.set_session_id() |
| 262 | + self.initialize_memory() |
| 263 | + |
| 264 | + # Update workflow_id for all agents before registration |
| 265 | + for field_name, value in self.__class__.__dict__.items(): |
| 266 | + if isinstance(value, Agent): |
| 267 | + value.initialize_agent() |
| 268 | + value.workflow_id = self.workflow_id |
| 269 | + |
| 270 | + if isinstance(value, Team): |
| 271 | + value.initialize_team() |
| 272 | + value.workflow_id = self.workflow_id |
| 273 | + |
| 274 | + # Register the workflow, which will also register agents and teams |
| 275 | + await self.aregister_workflow() |
| 276 | + |
| 277 | + # Create a run_id |
| 278 | + self.run_id = str(uuid4()) |
| 279 | + |
| 280 | + # Set run_input, run_response |
| 281 | + self.run_input = kwargs |
| 282 | + self.run_response = RunResponse(run_id=self.run_id, session_id=self.session_id, workflow_id=self.workflow_id) |
| 283 | + |
| 284 | + # Read existing session from storage |
| 285 | + self.read_from_storage() |
| 286 | + |
| 287 | + # Update the session_id for all Agent instances |
| 288 | + self.update_agent_session_ids() |
| 289 | + |
| 290 | + log_debug(f"Workflow Run Start: {self.run_id}", center=True) |
| 291 | + try: |
| 292 | + self._subclass_run = cast(Callable, self._subclass_run) |
| 293 | + result = await self._subclass_run(**kwargs) |
| 294 | + except Exception as e: |
| 295 | + logger.error(f"Workflow.arun() failed: {e}") |
| 296 | + raise e |
| 297 | + |
| 298 | + # Handle async iterator results |
| 299 | + if isinstance(result, (AsyncIterator, AsyncGenerator)): |
| 300 | + # Initialize the run_response content |
| 301 | + self.run_response.content = "" |
| 302 | + |
| 303 | + async def result_generator(): |
| 304 | + self.run_response = cast(RunResponse, self.run_response) |
| 305 | + if isinstance(self.memory, WorkflowMemory): |
| 306 | + self.memory = cast(WorkflowMemory, self.memory) |
| 307 | + elif isinstance(self.memory, Memory): |
| 308 | + self.memory = cast(Memory, self.memory) |
| 309 | + |
| 310 | + async for item in result: |
| 311 | + if isinstance(item, RunResponse): |
| 312 | + # Update the run_id, session_id and workflow_id of the RunResponse |
| 313 | + item.run_id = self.run_id |
| 314 | + item.session_id = self.session_id |
| 315 | + item.workflow_id = self.workflow_id |
| 316 | + |
| 317 | + # Update the run_response with the content from the result |
| 318 | + if item.content is not None and isinstance(item.content, str): |
| 319 | + self.run_response.content += item.content |
| 320 | + else: |
| 321 | + logger.warning(f"Workflow.arun() should only yield RunResponse objects, got: {type(item)}") |
| 322 | + yield item |
| 323 | + |
| 324 | + # Add the run to the memory |
| 325 | + if isinstance(self.memory, WorkflowMemory): |
| 326 | + self.memory.add_run(WorkflowRun(input=self.run_input, response=self.run_response)) |
| 327 | + elif isinstance(self.memory, Memory): |
| 328 | + self.memory.add_run(session_id=self.session_id, run=self.run_response) # type: ignore |
| 329 | + # Write this run to the database |
| 330 | + self.write_to_storage() |
| 331 | + log_debug(f"Workflow Run End: {self.run_id}", center=True) |
| 332 | + |
| 333 | + return result_generator() |
| 334 | + # Handle single RunResponse result |
| 335 | + elif isinstance(result, RunResponse): |
| 336 | + # Update the result with the run_id, session_id and workflow_id of the workflow run |
| 337 | + result.run_id = self.run_id |
| 338 | + result.session_id = self.session_id |
| 339 | + result.workflow_id = self.workflow_id |
| 340 | + |
| 341 | + # Update the run_response with the content from the result |
| 342 | + if result.content is not None and isinstance(result.content, str): |
| 343 | + self.run_response.content = result.content |
| 344 | + |
| 345 | + # Add the run to the memory |
| 346 | + if isinstance(self.memory, WorkflowMemory): |
| 347 | + self.memory.add_run(WorkflowRun(input=self.run_input, response=self.run_response)) |
| 348 | + elif isinstance(self.memory, Memory): |
| 349 | + self.memory.add_run(session_id=self.session_id, run=self.run_response) # type: ignore |
| 350 | + # Write this run to the database |
| 351 | + self.write_to_storage() |
| 352 | + log_debug(f"Workflow Run End: {self.run_id}", center=True) |
| 353 | + return result |
| 354 | + else: |
| 355 | + logger.warning(f"Workflow.arun() should only return RunResponse objects, got: {type(result)}") |
| 356 | + return None |
| 357 | + |
| 358 | + async def arun(self, **kwargs: Any): |
| 359 | + """Async version of run() that calls arun_workflow()""" |
| 360 | + logger.error(f"{self.__class__.__name__}.arun() method not implemented.") |
| 361 | + return |
| 362 | + |
251 | 363 | def set_storage_mode(self): |
252 | 364 | if self.storage is not None: |
253 | 365 | self.storage.mode = "workflow" |
@@ -294,11 +406,17 @@ def update_run_method(self): |
294 | 406 | # First, check if the subclass has a run method |
295 | 407 | # If the run() method has been overridden by the subclass, |
296 | 408 | # then self.__class__.run is not Workflow.run will be True |
297 | | - if self.__class__.run is not Workflow.run: |
298 | | - # Store the original run method bound to the instance in self._subclass_run |
299 | | - self._subclass_run = self.__class__.run.__get__(self) |
300 | | - # Get the parameters of the run method |
301 | | - sig = inspect.signature(self.__class__.run) |
| 409 | + if self.__class__.run is not Workflow.run or self.__class__.arun is not Workflow.arun: |
| 410 | + # Store the original run methods bound to the instance |
| 411 | + if self.__class__.run is not Workflow.run: |
| 412 | + self._subclass_run = self.__class__.run.__get__(self) |
| 413 | + # Get the parameters of the sync run method |
| 414 | + sig = inspect.signature(self.__class__.run) |
| 415 | + if self.__class__.arun is not Workflow.arun: |
| 416 | + self._subclass_run = self.__class__.arun.__get__(self) |
| 417 | + # Get the parameters of the async run method |
| 418 | + sig = inspect.signature(self.__class__.arun) |
| 419 | + |
302 | 420 | # Convert parameters to a serializable format |
303 | 421 | self._run_parameters = { |
304 | 422 | param_name: { |
@@ -662,6 +780,31 @@ def _deep_copy_field(self, field_name: str, field_value: Any) -> Any: |
662 | 780 | # For other types, return as is |
663 | 781 | return field_value |
664 | 782 |
|
| 783 | + async def aregister_workflow(self, force: bool = False) -> None: |
| 784 | + """Async version of register_workflow""" |
| 785 | + self.set_monitoring() |
| 786 | + if not self.monitoring: |
| 787 | + return |
| 788 | + |
| 789 | + if not self.workflow_id: |
| 790 | + self.set_workflow_id() |
| 791 | + |
| 792 | + try: |
| 793 | + from agno.api.schemas.workflows import WorkflowCreate |
| 794 | + from agno.api.workflows import acreate_workflow |
| 795 | + |
| 796 | + workflow_config = self.to_config_dict() |
| 797 | + # Register the workflow as an app |
| 798 | + await acreate_workflow( |
| 799 | + workflow=WorkflowCreate( |
| 800 | + name=self.name, workflow_id=self.workflow_id, app_id=self.app_id, config=workflow_config |
| 801 | + ) |
| 802 | + ) |
| 803 | + |
| 804 | + log_debug(f"Registered workflow: {self.name} (ID: {self.workflow_id})") |
| 805 | + except Exception as e: |
| 806 | + log_warning(f"Failed to register workflow: {e}") |
| 807 | + |
665 | 808 | def register_workflow(self, force: bool = False) -> None: |
666 | 809 | """Register this workflow with Agno's platform. |
667 | 810 |
|
|
0 commit comments