Skip to content

fix: Add Input Validation for Task Context IDs in new_task Function #340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def new_task(request: Message) -> Task:

Raises:
TypeError: If the message role is None.
ValueError: If the message parts are empty or if any part has empty content.
ValueError: If the message parts are empty, if any part has empty content, or if the provided context_id is invalid.
"""
if not request.role:
raise TypeError('Message role cannot be None')
Expand All @@ -28,12 +28,25 @@ def new_task(request: Message) -> Task:
if isinstance(part.root, TextPart) and not part.root.text:
raise ValueError('TextPart content cannot be empty')

context_id_str = request.context_id
if context_id_str is not None:
try:
# Validate that the provided context_id is a valid UUID
uuid.UUID(context_id_str)
context_id = context_id_str
except (ValueError, AttributeError, TypeError):
# Catch a variety of potential issues with the UUID validation
raise ValueError(
f"Invalid context_id: '{context_id_str}' is not a valid UUID."
)
else:
# Generate a new UUID if no context_id is provided
context_id = str(uuid.uuid4())

return Task(
status=TaskStatus(state=TaskState.submitted),
id=(request.task_id if request.task_id else str(uuid.uuid4())),
context_id=(
request.context_id if request.context_id else str(uuid.uuid4())
),
context_id=context_id,
history=[request],
)

Expand Down
39 changes: 39 additions & 0 deletions tests/utils/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,45 @@ def test_completed_task_invalid_artifact_type(self):
history=[],
)

def test_new_task_with_invalid_context_id(self):
with pytest.raises(
ValueError,
match="Invalid context_id: 'not-a-uuid' is not a valid UUID.",
):
new_task(
Message(
role=Role.user,
parts=[Part(root=TextPart(text='test message'))],
message_id=str(uuid.uuid4()),
context_id='not-a-uuid',
)
)

def test_new_task_with_empty_string_context_id(self):
with pytest.raises(
ValueError, match="Invalid context_id: '' is not a valid UUID."
):
new_task(
Message(
role=Role.user,
parts=[Part(root=TextPart(text='test message'))],
message_id=str(uuid.uuid4()),
context_id='',
)
)

def test_new_task_with_valid_context_id(self):
valid_uuid = '123e4567-e89b-12d3-a456-426614174000'
task = new_task(
Message(
role=Role.user,
parts=[Part(root=TextPart(text='test message'))],
message_id=str(uuid.uuid4()),
context_id=valid_uuid,
)
)
self.assertEqual(task.context_id, valid_uuid)


if __name__ == '__main__':
unittest.main()
Loading