Skip to content

Commit a227f4f

Browse files
committed
test: improve coverage
Signed-off-by: Shingo OKAWA <shingo.okawa.g.h.c@gmail.com>
1 parent a38d438 commit a227f4f

File tree

3 files changed

+274
-7
lines changed

3 files changed

+274
-7
lines changed

src/a2a/client/grpc_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ async def send_message(
6464
metadata=proto_utils.ToProto.metadata(request.metadata),
6565
)
6666
)
67-
if response.task:
67+
if response.HasField('task'):
6868
return proto_utils.FromProto.task(response.task)
6969
return proto_utils.FromProto.message(response.msg)
7070

@@ -87,7 +87,7 @@ async def send_message_streaming(
8787
`TaskArtifactUpdateEvent` objects as they are received in the
8888
stream.
8989
"""
90-
stream = self.stub.SendStreamingMessage(
90+
stream = await self.stub.SendStreamingMessage(
9191
a2a_pb2.SendMessageRequest(
9292
request=proto_utils.ToProto.message(request.message),
9393
configuration=proto_utils.ToProto.message_send_configuration(

src/a2a/utils/proto_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515

1616
# Regexp patterns for matching
17-
_TASK_NAME_MATCH = r'tasks/(\w+)'
18-
_TASK_PUSH_CONFIG_NAME_MATCH = r'tasks/(\w+)/pushNotificationConfigs/(\w+)'
17+
_TASK_NAME_MATCH = r'tasks/([a-zA-Z0-9_.-]+)'
18+
_TASK_PUSH_CONFIG_NAME_MATCH = (
19+
r'tasks/([a-zA-Z0-9_.-]+)/pushNotificationConfigs/([a-zA-Z0-9_.-]+)'
20+
)
1921

2022

2123
class ToProto:
@@ -631,6 +633,7 @@ def task_push_notification_config(
631633
request: a2a_pb2.CreateTaskPushNotificationConfigRequest,
632634
) -> types.TaskPushNotificationConfig:
633635
m = re.match(_TASK_NAME_MATCH, request.parent)
636+
print(m)
634637
if not m:
635638
raise ServerError(
636639
error=types.InvalidParamsError(

tests/client/test_grpc_client.py

Lines changed: 267 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
1-
from unittest.mock import AsyncMock
1+
from unittest.mock import AsyncMock, MagicMock
22

3+
import grpc
34
import pytest
45

56
from a2a.client import A2AGrpcClient
67
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
78
from a2a.types import (
89
AgentCapabilities,
910
AgentCard,
11+
Artifact,
1012
Message,
1113
MessageSendParams,
1214
Part,
15+
PushNotificationAuthenticationInfo,
16+
PushNotificationConfig,
1317
Role,
1418
Task,
19+
TaskArtifactUpdateEvent,
1520
TaskIdParams,
21+
TaskPushNotificationConfig,
1622
TaskQueryParams,
1723
TaskState,
1824
TaskStatus,
25+
TaskStatusUpdateEvent,
1926
TextPart,
2027
)
2128
from a2a.utils import proto_utils
29+
from a2a.utils.errors import ServerError
2230

2331

2432
# Fixtures
@@ -30,8 +38,8 @@ def mock_grpc_stub() -> AsyncMock:
3038
stub.SendStreamingMessage = AsyncMock()
3139
stub.GetTask = AsyncMock()
3240
stub.CancelTask = AsyncMock()
33-
stub.CreateTaskPushNotification = AsyncMock()
34-
stub.GetTaskPushNotification = AsyncMock()
41+
stub.CreateTaskPushNotificationConfig = AsyncMock()
42+
stub.GetTaskPushNotificationConfig = AsyncMock()
3543
return stub
3644

3745

@@ -90,6 +98,78 @@ def sample_message() -> Message:
9098
)
9199

92100

101+
@pytest.fixture
102+
def sample_artifact() -> Artifact:
103+
"""Provides a sample Artifact object."""
104+
return Artifact(
105+
artifactId='artifact-1',
106+
name='example.txt',
107+
description='An example artifact',
108+
parts=[Part(root=TextPart(text='Hi there'))],
109+
metadata={},
110+
extensions=[],
111+
)
112+
113+
114+
@pytest.fixture
115+
def sample_task_status_update_event() -> TaskStatusUpdateEvent:
116+
"""Provides a sample TaskStatusUpdateEvent."""
117+
return TaskStatusUpdateEvent(
118+
taskId='task-1',
119+
contextId='ctx-1',
120+
status=TaskStatus(state=TaskState.working),
121+
final=False,
122+
metadata={},
123+
)
124+
125+
126+
@pytest.fixture
127+
def sample_task_artifact_update_event(
128+
sample_artifact,
129+
) -> TaskArtifactUpdateEvent:
130+
"""Provides a sample TaskArtifactUpdateEvent."""
131+
return TaskArtifactUpdateEvent(
132+
taskId='task-1',
133+
contextId='ctx-1',
134+
artifact=sample_artifact,
135+
append=True,
136+
last_chunk=True,
137+
metadata={},
138+
)
139+
140+
141+
@pytest.fixture
142+
def sample_authentication_info() -> PushNotificationAuthenticationInfo:
143+
"""Provides a sample AuthenticationInfo object."""
144+
return PushNotificationAuthenticationInfo(
145+
schemes=['apikey', 'oauth2'], credentials='secret-token'
146+
)
147+
148+
149+
@pytest.fixture
150+
def sample_push_notification_config(
151+
sample_authentication_info: PushNotificationAuthenticationInfo,
152+
) -> PushNotificationConfig:
153+
"""Provides a sample PushNotificationConfig object."""
154+
return PushNotificationConfig(
155+
id='config-1',
156+
url='https://example.com/notify',
157+
token='example-token',
158+
authentication=sample_authentication_info,
159+
)
160+
161+
162+
@pytest.fixture
163+
def sample_task_push_notification_config(
164+
sample_push_notification_config: PushNotificationConfig,
165+
) -> TaskPushNotificationConfig:
166+
"""Provides a sample TaskPushNotificationConfig object."""
167+
return TaskPushNotificationConfig(
168+
taskId='task-1',
169+
pushNotificationConfig=sample_push_notification_config,
170+
)
171+
172+
93173
@pytest.mark.asyncio
94174
async def test_send_message_task_response(
95175
grpc_client: A2AGrpcClient,
@@ -109,6 +189,76 @@ async def test_send_message_task_response(
109189
assert response.id == sample_task.id
110190

111191

192+
@pytest.mark.asyncio
193+
async def test_send_message_message_response(
194+
grpc_client: A2AGrpcClient,
195+
mock_grpc_stub: AsyncMock,
196+
sample_message_send_params: MessageSendParams,
197+
sample_message: Message,
198+
):
199+
"""Test send_message that returns a Message."""
200+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
201+
msg=proto_utils.ToProto.message(sample_message)
202+
)
203+
204+
response = await grpc_client.send_message(sample_message_send_params)
205+
206+
mock_grpc_stub.SendMessage.assert_awaited_once()
207+
assert isinstance(response, Message)
208+
assert response.messageId == sample_message.messageId
209+
210+
211+
@pytest.mark.asyncio
212+
async def test_send_message_streaming(
213+
grpc_client: A2AGrpcClient,
214+
mock_grpc_stub: AsyncMock,
215+
sample_message_send_params: MessageSendParams,
216+
sample_message: Message,
217+
sample_task: Task,
218+
sample_task_status_update_event: TaskStatusUpdateEvent,
219+
sample_task_artifact_update_event: TaskArtifactUpdateEvent,
220+
):
221+
"""Test send_message_streaming that yields responses."""
222+
stream = MagicMock()
223+
stream.read = AsyncMock(
224+
side_effect=[
225+
a2a_pb2.StreamResponse(
226+
msg=proto_utils.ToProto.message(sample_message)
227+
),
228+
a2a_pb2.StreamResponse(task=proto_utils.ToProto.task(sample_task)),
229+
a2a_pb2.StreamResponse(
230+
status_update=proto_utils.ToProto.task_status_update_event(
231+
sample_task_status_update_event
232+
)
233+
),
234+
a2a_pb2.StreamResponse(
235+
artifact_update=proto_utils.ToProto.task_artifact_update_event(
236+
sample_task_artifact_update_event
237+
)
238+
),
239+
grpc.aio.EOF,
240+
]
241+
)
242+
mock_grpc_stub.SendStreamingMessage.return_value = stream
243+
244+
responses = [
245+
response
246+
async for response in grpc_client.send_message_streaming(
247+
sample_message_send_params
248+
)
249+
]
250+
251+
mock_grpc_stub.SendStreamingMessage.assert_awaited_once()
252+
assert isinstance(responses[0], Message)
253+
assert responses[0].messageId == sample_message.messageId
254+
assert isinstance(responses[1], Task)
255+
assert responses[1].id == sample_task.id
256+
assert isinstance(responses[2], TaskStatusUpdateEvent)
257+
assert responses[2].taskId == sample_task_status_update_event.taskId
258+
assert isinstance(responses[3], TaskArtifactUpdateEvent)
259+
assert responses[3].taskId == sample_task_artifact_update_event.taskId
260+
261+
112262
@pytest.mark.asyncio
113263
async def test_get_task(
114264
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task
@@ -143,3 +293,117 @@ async def test_cancel_task(
143293
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
144294
)
145295
assert response.status.state == TaskState.canceled
296+
297+
298+
@pytest.mark.asyncio
299+
async def test_set_task_callback_with_valid_task(
300+
grpc_client: A2AGrpcClient,
301+
mock_grpc_stub: AsyncMock,
302+
sample_task_push_notification_config: TaskPushNotificationConfig,
303+
):
304+
"""Test setting a task push notification config with a valid task id."""
305+
task_id = 'task-1'
306+
config_id = 'config-1'
307+
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = (
308+
a2a_pb2.CreateTaskPushNotificationConfigRequest(
309+
parent=f'tasks/{task_id}',
310+
config_id=config_id,
311+
config=proto_utils.ToProto.task_push_notification_config(
312+
sample_task_push_notification_config
313+
),
314+
)
315+
)
316+
317+
response = await grpc_client.set_task_callback(
318+
sample_task_push_notification_config
319+
)
320+
321+
mock_grpc_stub.CreateTaskPushNotificationConfig.assert_awaited_once_with(
322+
a2a_pb2.CreateTaskPushNotificationConfigRequest(
323+
config=proto_utils.ToProto.task_push_notification_config(
324+
sample_task_push_notification_config
325+
),
326+
)
327+
)
328+
assert response.taskId == task_id
329+
330+
331+
@pytest.mark.asyncio
332+
async def test_set_task_callback_with_invalid_task(
333+
grpc_client: A2AGrpcClient,
334+
mock_grpc_stub: AsyncMock,
335+
sample_task_push_notification_config: TaskPushNotificationConfig,
336+
):
337+
"""Test setting a task push notification config with a invalid task id."""
338+
task_id = 'task-1'
339+
config_id = 'config-1'
340+
mock_grpc_stub.CreateTaskPushNotificationConfig.return_value = (
341+
a2a_pb2.CreateTaskPushNotificationConfigRequest(
342+
parent=f'invalid-path-to-tasks/{task_id}',
343+
config_id=config_id,
344+
config=proto_utils.ToProto.task_push_notification_config(
345+
sample_task_push_notification_config
346+
),
347+
)
348+
)
349+
350+
with pytest.raises(ServerError) as exc_info:
351+
await grpc_client.set_task_callback(
352+
sample_task_push_notification_config
353+
)
354+
assert 'No task for' in exc_info.value.error.message
355+
356+
357+
@pytest.mark.asyncio
358+
async def test_get_task_callback_with_valid_task(
359+
grpc_client: A2AGrpcClient,
360+
mock_grpc_stub: AsyncMock,
361+
sample_task_push_notification_config: TaskPushNotificationConfig,
362+
):
363+
"""Test retrieving a task push notification config with a valid task id."""
364+
task_id = 'task-1'
365+
config_id = 'config-1'
366+
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = (
367+
a2a_pb2.CreateTaskPushNotificationConfigRequest(
368+
parent=f'tasks/{task_id}',
369+
config_id=config_id,
370+
config=proto_utils.ToProto.task_push_notification_config(
371+
sample_task_push_notification_config
372+
),
373+
)
374+
)
375+
params = TaskIdParams(id=sample_task_push_notification_config.taskId)
376+
377+
response = await grpc_client.get_task_callback(params)
378+
379+
mock_grpc_stub.GetTaskPushNotificationConfig.assert_awaited_once_with(
380+
a2a_pb2.GetTaskPushNotificationConfigRequest(
381+
name=f'tasks/{params.id}/pushNotification/undefined',
382+
)
383+
)
384+
assert response.taskId == task_id
385+
386+
387+
@pytest.mark.asyncio
388+
async def test_get_task_callback_with_invalid_task(
389+
grpc_client: A2AGrpcClient,
390+
mock_grpc_stub: AsyncMock,
391+
sample_task_push_notification_config: TaskPushNotificationConfig,
392+
):
393+
"""Test retrieving a task push notification config with a invalid task id."""
394+
task_id = 'task-1'
395+
config_id = 'config-1'
396+
mock_grpc_stub.GetTaskPushNotificationConfig.return_value = (
397+
a2a_pb2.CreateTaskPushNotificationConfigRequest(
398+
parent=f'invalid-path-to-tasks/{task_id}',
399+
config_id=config_id,
400+
config=proto_utils.ToProto.task_push_notification_config(
401+
sample_task_push_notification_config
402+
),
403+
)
404+
)
405+
params = TaskIdParams(id=sample_task_push_notification_config.taskId)
406+
407+
with pytest.raises(ServerError) as exc_info:
408+
await grpc_client.get_task_callback(params)
409+
assert 'No task for' in exc_info.value.error.message

0 commit comments

Comments
 (0)