Skip to content

Commit c7f2dd0

Browse files
test: Add new tests to increase coverage (#195)
Co-authored-by: kthota-g <kcthota@google.com>
1 parent aa63b98 commit c7f2dd0

File tree

11 files changed

+424
-20
lines changed

11 files changed

+424
-20
lines changed

.github/actions/spelling/excludes.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,5 @@
8888
CHANGELOG.md
8989
noxfile.py
9090
^src/a2a/grpc/
91+
^tests/
9192
.pre-commit-config.yaml

.github/workflows/unit-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ jobs:
2828
- name: Install dependencies
2929
run: uv sync --dev
3030
- name: Run tests and check coverage
31-
run: uv run pytest --cov=a2a --cov-report=xml --cov-fail-under=85
31+
run: uv run pytest --cov=a2a --cov-report=xml --cov-fail-under=90
3232
- name: Show coverage summary in log
3333
run: uv run coverage report

src/a2a/utils/proto_utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,16 @@ def authentication_info(
143143
def push_notification_config(
144144
cls, config: types.PushNotificationConfig
145145
) -> a2a_pb2.PushNotificationConfig:
146+
auth_info = (
147+
ToProto.authentication_info(config.authentication)
148+
if config.authentication
149+
else None
150+
)
146151
return a2a_pb2.PushNotificationConfig(
147152
id=config.id or '',
148153
url=config.url,
149154
token=config.token,
150-
authentication=ToProto.authentication_info(config.authentication),
155+
authentication=auth_info,
151156
)
152157

153158
@classmethod
@@ -185,7 +190,9 @@ def message_send_configuration(
185190
accepted_output_modes=list(config.acceptedOutputModes),
186191
push_notification=ToProto.push_notification_config(
187192
config.pushNotificationConfig
188-
),
193+
)
194+
if config.pushNotificationConfig
195+
else None,
189196
history_length=config.historyLength,
190197
blocking=config.blocking or False,
191198
)
@@ -335,7 +342,7 @@ def security_scheme(
335342
return a2a_pb2.SecurityScheme(
336343
api_key_security_scheme=a2a_pb2.APIKeySecurityScheme(
337344
description=scheme.root.description,
338-
location=scheme.root.in_,
345+
location=scheme.root.in_.value,
339346
name=scheme.root.name,
340347
)
341348
)
@@ -548,7 +555,9 @@ def push_notification_config(
548555
id=config.id,
549556
url=config.url,
550557
token=config.token,
551-
authentication=FromProto.authentication_info(config.authentication),
558+
authentication=FromProto.authentication_info(config.authentication)
559+
if config.HasField('authentication')
560+
else None,
552561
)
553562

554563
@classmethod
@@ -568,7 +577,9 @@ def message_send_configuration(
568577
acceptedOutputModes=list(config.accepted_output_modes),
569578
pushNotificationConfig=FromProto.push_notification_config(
570579
config.push_notification
571-
),
580+
)
581+
if config.HasField('push_notification')
582+
else None,
572583
historyLength=config.history_length,
573584
blocking=config.blocking,
574585
)
@@ -720,7 +731,7 @@ def security_scheme(
720731
root=types.APIKeySecurityScheme(
721732
description=scheme.api_key_security_scheme.description,
722733
name=scheme.api_key_security_scheme.name,
723-
in_=scheme.api_key_security_scheme.location, # type: ignore[call-arg]
734+
in_=types.In(scheme.api_key_security_scheme.location), # type: ignore[call-arg]
724735
)
725736
)
726737
if scheme.HasField('http_auth_security_scheme'):

tests/client/test_grpc_client.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from unittest.mock import AsyncMock
2+
3+
import pytest
4+
5+
from a2a.client import A2AGrpcClient
6+
from a2a.grpc import a2a_pb2, a2a_pb2_grpc
7+
from a2a.types import (
8+
AgentCapabilities,
9+
AgentCard,
10+
Message,
11+
MessageSendParams,
12+
Part,
13+
Role,
14+
Task,
15+
TaskIdParams,
16+
TaskQueryParams,
17+
TaskState,
18+
TaskStatus,
19+
TextPart,
20+
)
21+
from a2a.utils import proto_utils
22+
23+
24+
# Fixtures
25+
@pytest.fixture
26+
def mock_grpc_stub() -> AsyncMock:
27+
"""Provides a mock gRPC stub with methods mocked."""
28+
stub = AsyncMock(spec=a2a_pb2_grpc.A2AServiceStub)
29+
stub.SendMessage = AsyncMock()
30+
stub.SendStreamingMessage = AsyncMock()
31+
stub.GetTask = AsyncMock()
32+
stub.CancelTask = AsyncMock()
33+
stub.CreateTaskPushNotification = AsyncMock()
34+
stub.GetTaskPushNotification = AsyncMock()
35+
return stub
36+
37+
38+
@pytest.fixture
39+
def sample_agent_card() -> AgentCard:
40+
"""Provides a minimal agent card for initialization."""
41+
return AgentCard(
42+
name='gRPC Test Agent',
43+
description='Agent for testing gRPC client',
44+
url='grpc://localhost:50051',
45+
version='1.0',
46+
capabilities=AgentCapabilities(streaming=True, pushNotifications=True),
47+
defaultInputModes=['text/plain'],
48+
defaultOutputModes=['text/plain'],
49+
skills=[],
50+
)
51+
52+
53+
@pytest.fixture
54+
def grpc_client(
55+
mock_grpc_stub: AsyncMock, sample_agent_card: AgentCard
56+
) -> A2AGrpcClient:
57+
"""Provides an A2AGrpcClient instance."""
58+
return A2AGrpcClient(grpc_stub=mock_grpc_stub, agent_card=sample_agent_card)
59+
60+
61+
@pytest.fixture
62+
def sample_message_send_params() -> MessageSendParams:
63+
"""Provides a sample MessageSendParams object."""
64+
return MessageSendParams(
65+
message=Message(
66+
role=Role.user,
67+
messageId='msg-1',
68+
parts=[Part(root=TextPart(text='Hello'))],
69+
)
70+
)
71+
72+
73+
@pytest.fixture
74+
def sample_task() -> Task:
75+
"""Provides a sample Task object."""
76+
return Task(
77+
id='task-1',
78+
contextId='ctx-1',
79+
status=TaskStatus(state=TaskState.completed),
80+
)
81+
82+
83+
@pytest.fixture
84+
def sample_message() -> Message:
85+
"""Provides a sample Message object."""
86+
return Message(
87+
role=Role.agent,
88+
messageId='msg-response',
89+
parts=[Part(root=TextPart(text='Hi there'))],
90+
)
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_send_message_task_response(
95+
grpc_client: A2AGrpcClient,
96+
mock_grpc_stub: AsyncMock,
97+
sample_message_send_params: MessageSendParams,
98+
sample_task: Task,
99+
):
100+
"""Test send_message that returns a Task."""
101+
mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse(
102+
task=proto_utils.ToProto.task(sample_task)
103+
)
104+
105+
response = await grpc_client.send_message(sample_message_send_params)
106+
107+
mock_grpc_stub.SendMessage.assert_awaited_once()
108+
assert isinstance(response, Task)
109+
assert response.id == sample_task.id
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_get_task(
114+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task
115+
):
116+
"""Test retrieving a task."""
117+
mock_grpc_stub.GetTask.return_value = proto_utils.ToProto.task(sample_task)
118+
params = TaskQueryParams(id=sample_task.id)
119+
120+
response = await grpc_client.get_task(params)
121+
122+
mock_grpc_stub.GetTask.assert_awaited_once_with(
123+
a2a_pb2.GetTaskRequest(name=f'tasks/{sample_task.id}')
124+
)
125+
assert response.id == sample_task.id
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_cancel_task(
130+
grpc_client: A2AGrpcClient, mock_grpc_stub: AsyncMock, sample_task: Task
131+
):
132+
"""Test cancelling a task."""
133+
cancelled_task = sample_task.model_copy()
134+
cancelled_task.status.state = TaskState.canceled
135+
mock_grpc_stub.CancelTask.return_value = proto_utils.ToProto.task(
136+
cancelled_task
137+
)
138+
params = TaskIdParams(id=sample_task.id)
139+
140+
response = await grpc_client.cancel_task(params)
141+
142+
mock_grpc_stub.CancelTask.assert_awaited_once_with(
143+
a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}')
144+
)
145+
assert response.status.state == TaskState.canceled

tests/server/agent_execution/test_context.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
MessageSendParams,
1111
Task,
1212
)
13+
from a2a.utils.errors import ServerError
1314

1415

1516
class TestRequestContext:
@@ -165,6 +166,33 @@ def test_check_or_generate_context_id_with_existing_context_id(
165166
assert context.context_id == existing_id
166167
assert mock_params.message.contextId == existing_id
167168

169+
def test_init_raises_error_on_task_id_mismatch(
170+
self, mock_params, mock_task
171+
):
172+
"""Test that an error is raised if provided task_id mismatches task.id."""
173+
with pytest.raises(ServerError) as exc_info:
174+
RequestContext(
175+
request=mock_params, task_id='wrong-task-id', task=mock_task
176+
)
177+
assert 'bad task id' in str(exc_info.value.error.message)
178+
179+
def test_init_raises_error_on_context_id_mismatch(
180+
self, mock_params, mock_task
181+
):
182+
"""Test that an error is raised if provided context_id mismatches task.contextId."""
183+
# Set a valid task_id to avoid that error
184+
mock_params.message.taskId = mock_task.id
185+
186+
with pytest.raises(ServerError) as exc_info:
187+
RequestContext(
188+
request=mock_params,
189+
task_id=mock_task.id,
190+
context_id='wrong-context-id',
191+
task=mock_task,
192+
)
193+
194+
assert 'bad context id' in str(exc_info.value.error.message)
195+
168196
def test_with_related_tasks_provided(self, mock_task):
169197
"""Test initialization with related tasks provided."""
170198
related_tasks = [mock_task, Mock(spec=Task)]

tests/server/apps/jsonrpc/test_jsonrpc_app.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,20 @@ def test_jsonrpc_app_build_method_abstract_raises_typeerror(
7070
# Ensure 'supportsAuthenticatedExtendedCard' attribute exists
7171
mock_agent_card.supportsAuthenticatedExtendedCard = False
7272

73-
class AbstractTester(JSONRPCApplication):
74-
# No 'build' method implemented
75-
pass
76-
77-
# Instantiating an ABC subclass that doesn't implement all abstract methods raises TypeError
73+
# This will fail at definition time if an abstract method is not implemented
7874
with pytest.raises(
7975
TypeError,
80-
match="Can't instantiate abstract class AbstractTester with abstract method build",
76+
match="Can't instantiate abstract class IncompleteJSONRPCApp with abstract method build",
8177
):
82-
# Using positional arguments for the abstract class constructor
83-
AbstractTester(mock_handler, mock_agent_card)
78+
79+
class IncompleteJSONRPCApp(JSONRPCApplication):
80+
# Intentionally not implementing 'build'
81+
def some_other_method(self):
82+
pass
83+
84+
IncompleteJSONRPCApp(
85+
agent_card=mock_agent_card, http_handler=mock_handler
86+
)
8487

8588

8689
if __name__ == '__main__':

tests/server/request_handlers/test_grpc_handler.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,87 @@ async def test_get_agent_card(
195195

196196
assert response.name == sample_agent_card.name
197197
assert response.version == sample_agent_card.version
198+
199+
200+
@pytest.mark.asyncio
201+
@pytest.mark.parametrize(
202+
'server_error, grpc_status_code, error_message_part',
203+
[
204+
(
205+
ServerError(error=types.JSONParseError()),
206+
grpc.StatusCode.INTERNAL,
207+
'JSONParseError',
208+
),
209+
(
210+
ServerError(error=types.InvalidRequestError()),
211+
grpc.StatusCode.INVALID_ARGUMENT,
212+
'InvalidRequestError',
213+
),
214+
(
215+
ServerError(error=types.MethodNotFoundError()),
216+
grpc.StatusCode.NOT_FOUND,
217+
'MethodNotFoundError',
218+
),
219+
(
220+
ServerError(error=types.InvalidParamsError()),
221+
grpc.StatusCode.INVALID_ARGUMENT,
222+
'InvalidParamsError',
223+
),
224+
(
225+
ServerError(error=types.InternalError()),
226+
grpc.StatusCode.INTERNAL,
227+
'InternalError',
228+
),
229+
(
230+
ServerError(error=types.TaskNotFoundError()),
231+
grpc.StatusCode.NOT_FOUND,
232+
'TaskNotFoundError',
233+
),
234+
(
235+
ServerError(error=types.TaskNotCancelableError()),
236+
grpc.StatusCode.UNIMPLEMENTED,
237+
'TaskNotCancelableError',
238+
),
239+
(
240+
ServerError(error=types.PushNotificationNotSupportedError()),
241+
grpc.StatusCode.UNIMPLEMENTED,
242+
'PushNotificationNotSupportedError',
243+
),
244+
(
245+
ServerError(error=types.UnsupportedOperationError()),
246+
grpc.StatusCode.UNIMPLEMENTED,
247+
'UnsupportedOperationError',
248+
),
249+
(
250+
ServerError(error=types.ContentTypeNotSupportedError()),
251+
grpc.StatusCode.UNIMPLEMENTED,
252+
'ContentTypeNotSupportedError',
253+
),
254+
(
255+
ServerError(error=types.InvalidAgentResponseError()),
256+
grpc.StatusCode.INTERNAL,
257+
'InvalidAgentResponseError',
258+
),
259+
(
260+
ServerError(error=types.JSONRPCError(code=99, message='Unknown')),
261+
grpc.StatusCode.UNKNOWN,
262+
'Unknown error',
263+
),
264+
],
265+
)
266+
async def test_abort_context_error_mapping(
267+
grpc_handler: GrpcHandler,
268+
mock_request_handler: AsyncMock,
269+
mock_grpc_context: AsyncMock,
270+
server_error,
271+
grpc_status_code,
272+
error_message_part,
273+
):
274+
mock_request_handler.on_get_task.side_effect = server_error
275+
request_proto = a2a_pb2.GetTaskRequest(name='tasks/any')
276+
await grpc_handler.GetTask(request_proto, mock_grpc_context)
277+
278+
mock_grpc_context.abort.assert_awaited_once()
279+
call_args, _ = mock_grpc_context.abort.call_args
280+
assert call_args[0] == grpc_status_code
281+
assert error_message_part in call_args[1]

0 commit comments

Comments
 (0)