Skip to content

Commit c4a324d

Browse files
authored
fix: Add validation for empty artifact lists in completed_task (#308)
This update addresses an issue where the `completed_task` function in `a2a/utils/task.py` did not perform any validation on the `artifacts` list. This could lead to unexpected behavior if an empty or invalid list was provided. This change introduces a validation check to ensure that the `artifacts` list is a non-empty list of `Artifact` objects, raising a `ValueError` if the validation fails. **Changes:** - Modified `a2a/utils/task.py` to add a validation check for the `artifacts` parameter in the `completed_task` function. - Updated `tests/utils/test_task.py` to include tests for the new validation logic, covering cases with empty lists and lists containing invalid items.
1 parent b94b8f5 commit c4a324d

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

src/a2a/utils/task.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def completed_task(
5555
Returns:
5656
A `Task` object with status set to 'completed'.
5757
"""
58+
if not artifacts or not all(isinstance(a, Artifact) for a in artifacts):
59+
raise ValueError(
60+
'artifacts must be a non-empty list of Artifact objects'
61+
)
62+
5863
if history is None:
5964
history = []
6065
return Task(

tests/utils/test_task.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
from unittest.mock import patch
55

6-
from a2a.types import Message, Part, Role, TextPart
6+
import pytest
7+
8+
from a2a.types import Artifact, Message, Part, Role, TextPart
79
from a2a.utils.task import completed_task, new_task
810

911

@@ -57,7 +59,12 @@ def test_new_task_initial_message_in_history(self):
5759
def test_completed_task_status(self):
5860
task_id = str(uuid.uuid4())
5961
context_id = str(uuid.uuid4())
60-
artifacts = [] # Artifacts should be of type Artifact
62+
artifacts = [
63+
Artifact(
64+
artifactId='artifact_1',
65+
parts=[Part(root=TextPart(text='some content'))],
66+
)
67+
]
6168
task = completed_task(
6269
task_id=task_id,
6370
context_id=context_id,
@@ -69,7 +76,12 @@ def test_completed_task_status(self):
6976
def test_completed_task_assigns_ids_and_artifacts(self):
7077
task_id = str(uuid.uuid4())
7178
context_id = str(uuid.uuid4())
72-
artifacts = [] # Artifacts should be of type Artifact
79+
artifacts = [
80+
Artifact(
81+
artifactId='artifact_1',
82+
parts=[Part(root=TextPart(text='some content'))],
83+
)
84+
]
7385
task = completed_task(
7486
task_id=task_id,
7587
context_id=context_id,
@@ -83,7 +95,12 @@ def test_completed_task_assigns_ids_and_artifacts(self):
8395
def test_completed_task_empty_history_if_not_provided(self):
8496
task_id = str(uuid.uuid4())
8597
context_id = str(uuid.uuid4())
86-
artifacts = [] # Artifacts should be of type Artifact
98+
artifacts = [
99+
Artifact(
100+
artifactId='artifact_1',
101+
parts=[Part(root=TextPart(text='some content'))],
102+
)
103+
]
87104
task = completed_task(
88105
task_id=task_id, context_id=context_id, artifacts=artifacts
89106
)
@@ -92,7 +109,12 @@ def test_completed_task_empty_history_if_not_provided(self):
92109
def test_completed_task_uses_provided_history(self):
93110
task_id = str(uuid.uuid4())
94111
context_id = str(uuid.uuid4())
95-
artifacts = [] # Artifacts should be of type Artifact
112+
artifacts = [
113+
Artifact(
114+
artifactId='artifact_1',
115+
parts=[Part(root=TextPart(text='some content'))],
116+
)
117+
]
96118
history = [
97119
Message(
98120
role=Role.user,
@@ -132,6 +154,30 @@ def test_new_task_invalid_message_none_role(self):
132154
)
133155
new_task(msg)
134156

157+
def test_completed_task_empty_artifacts(self):
158+
with pytest.raises(
159+
ValueError,
160+
match='artifacts must be a non-empty list of Artifact objects',
161+
):
162+
completed_task(
163+
task_id='task-123',
164+
context_id='ctx-456',
165+
artifacts=[],
166+
history=[],
167+
)
168+
169+
def test_completed_task_invalid_artifact_type(self):
170+
with pytest.raises(
171+
ValueError,
172+
match='artifacts must be a non-empty list of Artifact objects',
173+
):
174+
completed_task(
175+
task_id='task-123',
176+
context_id='ctx-456',
177+
artifacts=['not an artifact'],
178+
history=[],
179+
)
180+
135181

136182
if __name__ == '__main__':
137183
unittest.main()

0 commit comments

Comments
 (0)