Skip to content

Commit 1022093

Browse files
pankaj-bindgemini-code-assist[bot]pstephengoogle
authored
feat(server): Add lock to TaskUpdater to prevent race conditions (#279)
# Description This PR adds a lock to the `TaskUpdater` to prevent race conditions when updating tasks that are in a terminal state (e.g., completed, failed). This ensures updates are handled atomically, making task state management more robust. ### **File Changes** * `src/a2a/server/tasks/task_updater.py` * `tests/server/tasks/test_task_updater.py` ### **Benefits** * **Robustness**: Prevents race conditions during concurrent task updates. * **Stability**: Ensures reliable task execution in production. * **Test Coverage**: Adds new tests for concurrent scenarios. Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests and linter pass (Run `nox -s format` from the repository root to format) - [x] Appropriate docs were updated (if necessary) Fixes #278 🦕 --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: pstephengoogle <pstephen@google.com>
1 parent cb08973 commit 1022093

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

src/a2a/server/tasks/task_updater.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import uuid
23

34
from datetime import datetime, timezone
@@ -33,6 +34,14 @@ def __init__(self, event_queue: EventQueue, task_id: str, context_id: str):
3334
self.event_queue = event_queue
3435
self.task_id = task_id
3536
self.context_id = context_id
37+
self._lock = asyncio.Lock()
38+
self._terminal_state_reached = False
39+
self._terminal_states = {
40+
TaskState.completed,
41+
TaskState.canceled,
42+
TaskState.failed,
43+
TaskState.rejected,
44+
}
3645

3746
async def update_status(
3847
self,
@@ -49,21 +58,26 @@ async def update_status(
4958
final: If True, indicates this is the final status update for the task.
5059
timestamp: Optional ISO 8601 datetime string. Defaults to current time.
5160
"""
52-
current_timestamp = (
53-
timestamp if timestamp else datetime.now(timezone.utc).isoformat()
54-
)
55-
await self.event_queue.enqueue_event(
56-
TaskStatusUpdateEvent(
57-
taskId=self.task_id,
58-
contextId=self.context_id,
59-
final=final,
60-
status=TaskStatus(
61-
state=state,
62-
message=message,
63-
timestamp=current_timestamp,
64-
),
61+
async with self._lock:
62+
if self._terminal_state_reached:
63+
raise RuntimeError(f"Task {self.task_id} is already in a terminal state.")
64+
if state in self._terminal_states:
65+
self._terminal_state_reached = True
66+
final = True
67+
68+
current_timestamp = timestamp if timestamp else datetime.now(timezone.utc).isoformat()
69+
await self.event_queue.enqueue_event(
70+
TaskStatusUpdateEvent(
71+
taskId=self.task_id,
72+
contextId=self.context_id,
73+
final=final,
74+
status=TaskStatus(
75+
state=state,
76+
message=message,
77+
timestamp=current_timestamp,
78+
),
79+
)
6580
)
66-
)
6781

6882
async def add_artifact( # noqa: PLR0913
6983
self,

tests/server/tasks/test_task_updater.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import asyncio
12
import uuid
2-
33
from unittest.mock import AsyncMock, patch
44

55
import pytest
@@ -505,3 +505,31 @@ async def test_cancel_with_message(task_updater, event_queue, sample_message):
505505
assert event.status.state == TaskState.canceled
506506
assert event.final is True
507507
assert event.status.message == sample_message
508+
509+
510+
@pytest.mark.asyncio
511+
async def test_update_status_raises_error_if_terminal_state_reached(task_updater, event_queue):
512+
await task_updater.complete()
513+
event_queue.reset_mock()
514+
with pytest.raises(RuntimeError):
515+
await task_updater.start_work()
516+
event_queue.enqueue_event.assert_not_called()
517+
518+
519+
@pytest.mark.asyncio
520+
async def test_concurrent_updates_race_condition(event_queue):
521+
task_updater = TaskUpdater(
522+
event_queue=event_queue,
523+
task_id="test-task-id",
524+
context_id="test-context-id",
525+
)
526+
tasks = [
527+
task_updater.complete(),
528+
task_updater.failed(),
529+
]
530+
results = await asyncio.gather(*tasks, return_exceptions=True)
531+
successes = [r for r in results if not isinstance(r, Exception)]
532+
failures = [r for r in results if isinstance(r, RuntimeError)]
533+
assert len(successes) == 1
534+
assert len(failures) == 1
535+
assert event_queue.enqueue_event.call_count == 1

0 commit comments

Comments
 (0)