Skip to content

Commit 4b39e50

Browse files
committed
test: improve coverage
Signed-off-by: Shingo OKAWA <shingo.okawa.g.h.c@gmail.com>
1 parent dc95e2a commit 4b39e50

File tree

2 files changed

+154
-3
lines changed

2 files changed

+154
-3
lines changed

src/a2a/client/grpc_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def send_message(
6464
metadata=proto_utils.ToProto.metadata(request.metadata),
6565
)
6666
)
67-
if response.task:
67+
if response.HasField('task'):
6868
return proto_utils.FromProto.task(response.task)
6969
return proto_utils.FromProto.message(response.msg)
7070

@@ -87,7 +87,7 @@ async def send_message_streaming(
8787
`TaskArtifactUpdateEvent` objects as they are received in the
8888
stream.
8989
"""
90-
stream = self.stub.SendStreamingMessage(
90+
stream = await self.stub.SendStreamingMessage(
9191
a2a_pb2.SendMessageRequest(
9292
request=proto_utils.ToProto.message(request.message),
9393
configuration=proto_utils.ToProto.message_send_configuration(

tests/client/test_grpc_client.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
1-
from unittest.mock import AsyncMock
1+
from unittest.mock import AsyncMock, MagicMock
22

3+
import grpc
34
import pytest
45

6+
from google.protobuf import struct_pb2
7+
58
from a2a.client import A2AGrpcClient
69
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
710
from a2a.types import (
811
AgentCapabilities,
912
AgentCard,
13+
Artifact,
1014
Message,
1115
MessageSendParams,
1216
Part,
17+
PushNotificationAuthenticationInfo,
18+
PushNotificationConfig,
1319
Role,
1420
Task,
21+
TaskArtifactUpdateEvent,
1522
TaskIdParams,
23+
TaskPushNotificationConfig,
1624
TaskQueryParams,
1725
TaskState,
1826
TaskStatus,
27+
TaskStatusUpdateEvent,
1928
TextPart,
2029
)
2130
from a2a.utils import proto_utils
@@ -90,6 +99,78 @@ def sample_message() -> Message:
9099
)
91100

92101

102+
@pytest.fixture
103+
def sample_artifact() -> Artifact:
104+
"""Provides a sample Artifact object."""
105+
return Artifact(
106+
artifactId='artifact-1',
107+
name='example.txt',
108+
description='An example artifact',
109+
parts=[Part(root=TextPart(text='Hi there'))],
110+
metadata=struct_pb2.Struct(),
111+
extensions=[],
112+
)
113+
114+
115+
@pytest.fixture
116+
def sample_task_status_update_event() -> TaskStatusUpdateEvent:
117+
"""Provides a sample TaskStatusUpdateEvent."""
118+
return TaskStatusUpdateEvent(
119+
taskId='task-1',
120+
contextId='ctx-1',
121+
status=TaskStatus(state=TaskState.working),
122+
final=False,
123+
metadata=struct_pb2.Struct(),
124+
)
125+
126+
127+
@pytest.fixture
128+
def sample_task_artifact_update_event(
129+
sample_artifact,
130+
) -> TaskArtifactUpdateEvent:
131+
"""Provides a sample TaskArtifactUpdateEvent."""
132+
return TaskArtifactUpdateEvent(
133+
taskId='task-1',
134+
contextId='ctx-1',
135+
artifact=sample_artifact,
136+
append=True,
137+
last_chunk=True,
138+
metadata=struct_pb2.Struct(),
139+
)
140+
141+
142+
@pytest.fixture
143+
def sample_authentication_info() -> PushNotificationAuthenticationInfo:
144+
"""Provides a sample AuthenticationInfo object."""
145+
return PushNotificationAuthenticationInfo(
146+
schemes=['apikey', 'oauth2'], credentials='secret-token'
147+
)
148+
149+
150+
@pytest.fixture
151+
def sample_push_notification_config(
152+
sample_authentication_info: PushNotificationAuthenticationInfo,
153+
) -> PushNotificationConfig:
154+
"""Provides a sample PushNotificationConfig object."""
155+
return PushNotificationConfig(
156+
id='config-1',
157+
url='https://example.com/notify',
158+
token='example-token',
159+
authentication=sample_authentication_info,
160+
)
161+
162+
163+
@pytest.fixture
164+
def sample_task_push_notification_config(
165+
sample_push_notification_config: PushNotificationConfig,
166+
) -> TaskPushNotificationConfig:
167+
"""Provides a sample TaskPushNotificationConfig object."""
168+
return TaskPushNotificationConfig(
169+
name='task-push-notification-config-name-1',
170+
push_notification_config=sample_push_notification_config,
171+
)
172+
173+
93174
@pytest.mark.asyncio
94175
async def test_send_message_task_response(
95176
grpc_client: A2AGrpcClient,
@@ -109,6 +190,76 @@ async def test_send_message_task_response(
109190
assert response.id == sample_task.id
110191

111192

193+
@pytest.mark.asyncio
194+
async def test_send_message_message_response(
195+
grpc_client: A2AGrpcClient,
196+
mock_grpc_stub: AsyncMock,
197+
sample_message_send_params: MessageSendParams,
198+
sample_message: Message,
199+
):
200+
"""Test send_message that returns a Message."""
201+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
202+
msg=proto_utils.ToProto.message(sample_message)
203+
)
204+
205+
response = await grpc_client.send_message(sample_message_send_params)
206+
207+
mock_grpc_stub.SendMessage.assert_awaited_once()
208+
assert isinstance(response, Message)
209+
assert response.messageId == sample_message.messageId
210+
211+
212+
@pytest.mark.asyncio
213+
async def test_send_message_streaming(
214+
grpc_client: A2AGrpcClient,
215+
mock_grpc_stub: AsyncMock,
216+
sample_message_send_params: MessageSendParams,
217+
sample_message: Message,
218+
sample_task: Task,
219+
sample_task_status_update_event: TaskStatusUpdateEvent,
220+
sample_task_artifact_update_event: TaskArtifactUpdateEvent,
221+
):
222+
"""Test send_message_streaming that yields responses."""
223+
stream = MagicMock()
224+
stream.read = AsyncMock(
225+
side_effect=[
226+
a2a_pb2.StreamResponse(
227+
msg=proto_utils.ToProto.message(sample_message)
228+
),
229+
a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(sample_task)),
230+
a2a_pb2.StreamResponse(
231+
status_update=proto_utils.ToProto.task_status_update_event(
232+
sample_task_status_update_event
233+
)
234+
),
235+
a2a_pb2.StreamResponse(
236+
artifact_update=proto_utils.ToProto.task_artifact_update_event(
237+
sample_task_artifact_update_event
238+
)
239+
),
240+
grpc.aio.EOF,
241+
]
242+
)
243+
mock_grpc_stub.SendStreamingMessage.return_value = stream
244+
245+
responses = [
246+
response
247+
async for response in grpc_client.send_message_streaming(
248+
sample_message_send_params
249+
)
250+
]
251+
252+
mock_grpc_stub.SendStreamingMessage.assert_awaited_once()
253+
assert isinstance(responses[0], Message)
254+
assert responses[0].messageId == sample_message.messageId
255+
assert isinstance(responses[1], Task)
256+
assert responses[1].id == sample_task.id
257+
assert isinstance(responses[2], TaskStatusUpdateEvent)
258+
assert responses[2].taskId == sample_task_status_update_event.taskId
259+
assert isinstance(responses[3], TaskArtifactUpdateEvent)
260+
assert responses[3].taskId == sample_task_artifact_update_event.taskId
261+
262+
112263
@pytest.mark.asyncio
113264
async def test_get_task(
114265
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task

0 commit comments

Comments
 (0)