Skip to content

test: improve test coverage #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/a2a/client/grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def send_message(
metadata=proto_utils.ToProto.metadata(request.metadata),
)
)
if response.task:
if response.HasField('task'):
return proto_utils.FromProto.task(response.task)
return proto_utils.FromProto.message(response.msg)

Expand All @@ -87,7 +87,7 @@ async def send_message_streaming(
`TaskArtifactUpdateEvent` objects as they are received in the
stream.
"""
stream = self.stub.SendStreamingMessage(
stream = await self.stub.SendStreamingMessage(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this should be awaited

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comment!

Looking at the implementation here:

it seems like channel might be a grpc.Channel.

However, based on how it’s being used in the current client implementation:

it looks like channel is intended to be a grpc.aio.Channel, in which case the SendStreamingMessage call should indeed be awaited.

That said, I’m not 100% certain either, so it’d be great to get confirmation from whoever originally implemented this part.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a2a_pb2.SendMessageRequest(
request=proto_utils.ToProto.message(request.message),
configuration=proto_utils.ToProto.message_send_configuration(
Expand Down
6 changes: 4 additions & 2 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@


# Regexp patterns for matching
_TASK_NAME_MATCH = r'tasks/(\w+)'
_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotificationConfigs/(\w+)'
_TASK_NAME_MATCH = r'tasks/([a-zA-Z0-9_.-]+)'
_TASK_PUSH_CONFIG_NAME_MATCH = (
r'tasks/([a-zA-Z0-9_.-]+)/pushNotificationConfigs/([a-zA-Z0-9_.-]+)'
)


class ToProto:
Expand Down
270 changes: 267 additions & 3 deletions tests/client/test_grpc_client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock

import grpc
import pytest

from a2a.client import A2AGrpcClient
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
from a2a.types import (
AgentCapabilities,
AgentCard,
Artifact,
Message,
MessageSendParams,
Part,
PushNotificationAuthenticationInfo,
PushNotificationConfig,
Role,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskPushNotificationConfig,
TaskQueryParams,
TaskState,
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
)
from a2a.utils import proto_utils
from a2a.utils.errors import ServerError


# Fixtures
Expand All @@ -30,8 +38,8 @@ def mock_grpc_stub() -> AsyncMock:
stub.SendStreamingMessage = AsyncMock()
stub.GetTask = AsyncMock()
stub.CancelTask = AsyncMock()
stub.CreateTaskPushNotification = AsyncMock()
stub.GetTaskPushNotification = AsyncMock()
stub.CreateTaskPushNotificationConfig = AsyncMock()
stub.GetTaskPushNotificationConfig = AsyncMock()
return stub


Expand Down Expand Up @@ -90,6 +98,78 @@ def sample_message() -> Message:
)


@pytest.fixture
def sample_artifact() -> Artifact:
"""Provides a sample Artifact object."""
return Artifact(
artifactId='artifact-1',
name='example.txt',
description='An example artifact',
parts=[Part(root=TextPart(text='Hi there'))],
metadata={},
extensions=[],
)


@pytest.fixture
def sample_task_status_update_event() -> TaskStatusUpdateEvent:
"""Provides a sample TaskStatusUpdateEvent."""
return TaskStatusUpdateEvent(
taskId='task-1',
contextId='ctx-1',
status=TaskStatus(state=TaskState.working),
final=False,
metadata={},
)


@pytest.fixture
def sample_task_artifact_update_event(
sample_artifact,
) -> TaskArtifactUpdateEvent:
"""Provides a sample TaskArtifactUpdateEvent."""
return TaskArtifactUpdateEvent(
taskId='task-1',
contextId='ctx-1',
artifact=sample_artifact,
append=True,
last_chunk=True,
metadata={},
)


@pytest.fixture
def sample_authentication_info() -> PushNotificationAuthenticationInfo:
"""Provides a sample AuthenticationInfo object."""
return PushNotificationAuthenticationInfo(
schemes=['apikey', 'oauth2'], credentials='secret-token'
)


@pytest.fixture
def sample_push_notification_config(
sample_authentication_info: PushNotificationAuthenticationInfo,
) -> PushNotificationConfig:
"""Provides a sample PushNotificationConfig object."""
return PushNotificationConfig(
id='config-1',
url='https://example.com/notify',
token='example-token',
authentication=sample_authentication_info,
)


@pytest.fixture
def sample_task_push_notification_config(
sample_push_notification_config: PushNotificationConfig,
) -> TaskPushNotificationConfig:
"""Provides a sample TaskPushNotificationConfig object."""
return TaskPushNotificationConfig(
taskId='task-1',
pushNotificationConfig=sample_push_notification_config,
)


@pytest.mark.asyncio
async def test_send_message_task_response(
grpc_client: A2AGrpcClient,
Expand All @@ -109,6 +189,76 @@ async def test_send_message_task_response(
assert response.id == sample_task.id


@pytest.mark.asyncio
async def test_send_message_message_response(
grpc_client: A2AGrpcClient,
mock_grpc_stub: AsyncMock,
sample_message_send_params: MessageSendParams,
sample_message: Message,
):
"""Test send_message that returns a Message."""
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
msg=proto_utils.ToProto.message(sample_message)
)

response = await grpc_client.send_message(sample_message_send_params)

mock_grpc_stub.SendMessage.assert_awaited_once()
assert isinstance(response, Message)
assert response.messageId == sample_message.messageId


@pytest.mark.asyncio
async def test_send_message_streaming(
grpc_client: A2AGrpcClient,
mock_grpc_stub: AsyncMock,
sample_message_send_params: MessageSendParams,
sample_message: Message,
sample_task: Task,
sample_task_status_update_event: TaskStatusUpdateEvent,
sample_task_artifact_update_event: TaskArtifactUpdateEvent,
):
"""Test send_message_streaming that yields responses."""
stream = MagicMock()
stream.read = AsyncMock(
side_effect=[
a2a_pb2.StreamResponse(
msg=proto_utils.ToProto.message(sample_message)
),
a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(sample_task)),
a2a_pb2.StreamResponse(
status_update=proto_utils.ToProto.task_status_update_event(
sample_task_status_update_event
)
),
a2a_pb2.StreamResponse(
artifact_update=proto_utils.ToProto.task_artifact_update_event(
sample_task_artifact_update_event
)
),
grpc.aio.EOF,
]
)
mock_grpc_stub.SendStreamingMessage.return_value = stream

responses = [
response
async for response in grpc_client.send_message_streaming(
sample_message_send_params
)
]

mock_grpc_stub.SendStreamingMessage.assert_awaited_once()
assert isinstance(responses[0], Message)
assert responses[0].messageId == sample_message.messageId
assert isinstance(responses[1], Task)
assert responses[1].id == sample_task.id
assert isinstance(responses[2], TaskStatusUpdateEvent)
assert responses[2].taskId == sample_task_status_update_event.taskId
assert isinstance(responses[3], TaskArtifactUpdateEvent)
assert responses[3].taskId == sample_task_artifact_update_event.taskId


@pytest.mark.asyncio
async def test_get_task(
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task
Expand Down Expand Up @@ -143,3 +293,117 @@ async def test_cancel_task(
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
)
assert response.status.state == TaskState.canceled


@pytest.mark.asyncio
async def test_set_task_callback_with_valid_task(
grpc_client: A2AGrpcClient,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test setting a task push notification config with a valid task id."""
task_id = 'task-1'
config_id = 'config-1'
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = (
a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'tasks/{task_id}',
config_id=config_id,
config=proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
),
)
)

response = await grpc_client.set_task_callback(
sample_task_push_notification_config
)

mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with(
a2a_pb2.CreateTaskPushNotificationConfigRequest(
config=proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
),
)
)
assert response.taskId == task_id


@pytest.mark.asyncio
async def test_set_task_callback_with_invalid_task(
grpc_client: A2AGrpcClient,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test setting a task push notification config with a invalid task id."""
task_id = 'task-1'
config_id = 'config-1'
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = (
a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'invalid-path-to-tasks/{task_id}',
config_id=config_id,
config=proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
),
)
)

with pytest.raises(ServerError) as exc_info:
await grpc_client.set_task_callback(
sample_task_push_notification_config
)
assert 'No task for' in exc_info.value.error.message


@pytest.mark.asyncio
async def test_get_task_callback_with_valid_task(
grpc_client: A2AGrpcClient,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test retrieving a task push notification config with a valid task id."""
task_id = 'task-1'
config_id = 'config-1'
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = (
a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'tasks/{task_id}',
config_id=config_id,
config=proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
),
)
)
params = TaskIdParams(id=sample_task_push_notification_config.taskId)

response = await grpc_client.get_task_callback(params)

mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with(
a2a_pb2.GetTaskPushNotificationConfigRequest(
name=f'tasks/{params.id}/pushNotification/undefined',
)
)
assert response.taskId == task_id


@pytest.mark.asyncio
async def test_get_task_callback_with_invalid_task(
grpc_client: A2AGrpcClient,
mock_grpc_stub: AsyncMock,
sample_task_push_notification_config: TaskPushNotificationConfig,
):
"""Test retrieving a task push notification config with a invalid task id."""
task_id = 'task-1'
config_id = 'config-1'
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = (
a2a_pb2.CreateTaskPushNotificationConfigRequest(
parent=f'invalid-path-to-tasks/{task_id}',
config_id=config_id,
config=proto_utils.ToProto.task_push_notification_config(
sample_task_push_notification_config
),
)
)
params = TaskIdParams(id=sample_task_push_notification_config.taskId)

with pytest.raises(ServerError) as exc_info:
await grpc_client.get_task_callback(params)
assert 'No task for' in exc_info.value.error.message
Loading