Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions camel/societies/workforce/workforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ class Workforce(BaseNode):
support native structured output. When disabled, the workforce
uses the native response_format parameter.
(default: :obj:`True`)
on_subtask_completed (Optional[Callable[[Task], None]], optional):
Callback function to be called when a subtask is completed.
(default: :obj:`None`)
on_subtask_failed (Optional[Callable[[Task], None]], optional):
Callback function to be called when a subtask fails.
(default: :obj:`None`)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need to add these 2 argument in this PR, could you describe more for this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was working on this issue
, but when I saw that coolbeevip was already working on it, I decided to drop it. I’m not sure how I ended up committing, but I will remove it.


Example:
>>> import asyncio
Expand Down Expand Up @@ -249,6 +255,8 @@ def __init__(
share_memory: bool = False,
use_structured_output_handler: bool = True,
task_timeout_seconds: Optional[float] = None,
on_subtask_completed: Optional[Callable[[Task], None]] = None,
on_subtask_failed: Optional[Callable[[Task], None]] = None,
) -> None:
super().__init__(description)
self._child_listening_tasks: Deque[
Expand All @@ -265,6 +273,9 @@ def __init__(
if self.use_structured_output_handler:
self.structured_handler = StructuredOutputHandler()
self.metrics_logger = WorkforceLogger(workforce_id=self.node_id)
# Optional user callbacks for subtask lifecycle notifications
self.on_subtask_completed = on_subtask_completed
self.on_subtask_failed = on_subtask_failed
self._task: Optional[Task] = None
self._pending_tasks: Deque[Task] = deque()
self._task_dependencies: Dict[str, List[str]] = {}
Expand Down Expand Up @@ -2517,6 +2528,14 @@ async def _handle_failed_task(self, task: Task) -> bool:
self._completed_tasks.append(task)
if task.id in self._assignees:
await self._channel.archive_task(task.id)
# Invoke failure callback before halting
if self.on_subtask_failed is not None:
try:
self.on_subtask_failed(task)
except Exception as cb_err:
logger.warning(
f"on_subtask_failed callback raised: {cb_err}"
)
return True

# If too many tasks are failing rapidly, also halt to prevent infinite
Expand Down Expand Up @@ -2654,6 +2673,15 @@ async def _handle_failed_task(self, task: Task) -> bool:
self._completed_tasks.append(task)
return False

# Notify failure after bookkeeping, before scheduling next work
if self.on_subtask_failed is not None:
try:
self.on_subtask_failed(task)
except Exception as cb_err:
logger.warning(
f"on_subtask_failed callback raised: {cb_err}"
)

logger.debug(
f"Task {task.id} failed and was {action_taken}. "
f"Updating dependency state."
Expand Down Expand Up @@ -2757,6 +2785,14 @@ async def _handle_completed_task(self, task: Task) -> None:
if task.id in self._assignees:
await self._channel.archive_task(task.id)

if self.on_subtask_completed is not None:
try:
self.on_subtask_completed(task)
except Exception as cb_err:
logger.warning(
f"on_subtask_completed callback raised: {cb_err}"
)

# Ensure it's in completed tasks set by updating if it exists or
# appending if it's new.
task_found_in_completed = False
Expand Down
Loading
Loading