Skip to content

Commit 045aea9

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Support API-Key for MCP Tool authentication
PiperOrigin-RevId: 776641474
1 parent 20279d9 commit 045aea9

File tree

2 files changed

+249
-13
lines changed

2 files changed

+249
-13
lines changed

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
from __future__ import annotations
1616

1717
import base64
18-
import json
1918
import logging
2019
from typing import Optional
2120

21+
from fastapi.openapi.models import APIKeyIn
2222
from google.genai.types import FunctionDeclaration
23-
from google.oauth2.credentials import Credentials
2423
from typing_extensions import override
2524

2625
from .._gemini_schema_util import _to_gemini_schema
@@ -58,6 +57,9 @@ class MCPTool(BaseAuthenticatedTool):
5857
5958
Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
6059
call the tool.
60+
61+
Note: For API key authentication, only header-based API keys are supported.
62+
Query and cookie-based API keys will result in authentication errors.
6163
"""
6264

6365
def __init__(
@@ -134,7 +136,19 @@ async def _run_async_impl(
134136
async def _get_headers(
135137
self, tool_context: ToolContext, credential: AuthCredential
136138
) -> Optional[dict[str, str]]:
137-
headers = None
139+
"""Extracts authentication headers from credentials.
140+
141+
Args:
142+
tool_context: The tool context of the current invocation.
143+
credential: The authentication credential to process.
144+
145+
Returns:
146+
Dictionary of headers to add to the request, or None if no auth.
147+
148+
Raises:
149+
ValueError: If API key authentication is configured for non-header location.
150+
"""
151+
headers: Optional[dict[str, str]] = None
138152
if credential:
139153
if credential.oauth2:
140154
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
@@ -167,10 +181,33 @@ async def _get_headers(
167181
)
168182
}
169183
elif credential.api_key:
170-
# For API keys, we'll add them as headers since MCP typically uses header-based auth
171-
# The specific header name would depend on the API, using a common default
172-
# TODO Allow user to specify the header name for API keys.
173-
headers = {"X-API-Key": credential.api_key}
184+
if (
185+
not self._credentials_manager
186+
or not self._credentials_manager._auth_config
187+
):
188+
error_msg = (
189+
"Cannot find corresponding auth scheme for API key credential"
190+
f" {credential}"
191+
)
192+
logger.error(error_msg)
193+
raise ValueError(error_msg)
194+
elif (
195+
self._credentials_manager._auth_config.auth_scheme.in_
196+
!= APIKeyIn.header
197+
):
198+
error_msg = (
199+
"MCPTool only supports header-based API key authentication."
200+
" Configured location:"
201+
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
202+
)
203+
logger.error(error_msg)
204+
raise ValueError(error_msg)
205+
else:
206+
headers = {
207+
self._credentials_manager._auth_config.auth_scheme.name: (
208+
credential.api_key
209+
)
210+
}
174211
elif credential.service_account:
175212
# Service accounts should be exchanged for access tokens before reaching this point
176213
logger.warning(

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 205 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# limitations under the License.
1414

1515
import sys
16-
from typing import Any
17-
from typing import Dict
1816
from unittest.mock import AsyncMock
1917
from unittest.mock import Mock
2018
from unittest.mock import patch
@@ -268,8 +266,102 @@ async def test_get_headers_http_basic(self):
268266
assert headers == {"Authorization": f"Basic {expected_encoded}"}
269267

270268
@pytest.mark.asyncio
271-
async def test_get_headers_api_key(self):
272-
"""Test header generation for API Key credentials."""
269+
async def test_get_headers_api_key_with_valid_header_scheme(self):
270+
"""Test header generation for API Key credentials with header-based auth scheme."""
271+
from fastapi.openapi.models import APIKey
272+
from fastapi.openapi.models import APIKeyIn
273+
from google.adk.auth.auth_schemes import AuthSchemeType
274+
275+
# Create auth scheme for header-based API key
276+
auth_scheme = APIKey(**{
277+
"type": AuthSchemeType.apiKey,
278+
"in": APIKeyIn.header,
279+
"name": "X-Custom-API-Key",
280+
})
281+
auth_credential = AuthCredential(
282+
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
283+
)
284+
285+
tool = MCPTool(
286+
mcp_tool=self.mock_mcp_tool,
287+
mcp_session_manager=self.mock_session_manager,
288+
auth_scheme=auth_scheme,
289+
auth_credential=auth_credential,
290+
)
291+
292+
tool_context = Mock(spec=ToolContext)
293+
headers = await tool._get_headers(tool_context, auth_credential)
294+
295+
assert headers == {"X-Custom-API-Key": "my_api_key"}
296+
297+
@pytest.mark.asyncio
298+
async def test_get_headers_api_key_with_query_scheme_raises_error(self):
299+
"""Test that API Key with query-based auth scheme raises ValueError."""
300+
from fastapi.openapi.models import APIKey
301+
from fastapi.openapi.models import APIKeyIn
302+
from google.adk.auth.auth_schemes import AuthSchemeType
303+
304+
# Create auth scheme for query-based API key (not supported)
305+
auth_scheme = APIKey(**{
306+
"type": AuthSchemeType.apiKey,
307+
"in": APIKeyIn.query,
308+
"name": "api_key",
309+
})
310+
auth_credential = AuthCredential(
311+
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
312+
)
313+
314+
tool = MCPTool(
315+
mcp_tool=self.mock_mcp_tool,
316+
mcp_session_manager=self.mock_session_manager,
317+
auth_scheme=auth_scheme,
318+
auth_credential=auth_credential,
319+
)
320+
321+
tool_context = Mock(spec=ToolContext)
322+
323+
with pytest.raises(
324+
ValueError,
325+
match="MCPTool only supports header-based API key authentication",
326+
):
327+
await tool._get_headers(tool_context, auth_credential)
328+
329+
@pytest.mark.asyncio
330+
async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
331+
"""Test that API Key with cookie-based auth scheme raises ValueError."""
332+
from fastapi.openapi.models import APIKey
333+
from fastapi.openapi.models import APIKeyIn
334+
from google.adk.auth.auth_schemes import AuthSchemeType
335+
336+
# Create auth scheme for cookie-based API key (not supported)
337+
auth_scheme = APIKey(**{
338+
"type": AuthSchemeType.apiKey,
339+
"in": APIKeyIn.cookie,
340+
"name": "session_id",
341+
})
342+
auth_credential = AuthCredential(
343+
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
344+
)
345+
346+
tool = MCPTool(
347+
mcp_tool=self.mock_mcp_tool,
348+
mcp_session_manager=self.mock_session_manager,
349+
auth_scheme=auth_scheme,
350+
auth_credential=auth_credential,
351+
)
352+
353+
tool_context = Mock(spec=ToolContext)
354+
355+
with pytest.raises(
356+
ValueError,
357+
match="MCPTool only supports header-based API key authentication",
358+
):
359+
await tool._get_headers(tool_context, auth_credential)
360+
361+
@pytest.mark.asyncio
362+
async def test_get_headers_api_key_without_auth_config_raises_error(self):
363+
"""Test that API Key without auth config raises ValueError."""
364+
# Create tool without auth scheme/config
273365
tool = MCPTool(
274366
mcp_tool=self.mock_mcp_tool,
275367
mcp_session_manager=self.mock_session_manager,
@@ -278,11 +370,37 @@ async def test_get_headers_api_key(self):
278370
credential = AuthCredential(
279371
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
280372
)
373+
tool_context = Mock(spec=ToolContext)
374+
375+
with pytest.raises(
376+
ValueError,
377+
match="Cannot find corresponding auth scheme for API key credential",
378+
):
379+
await tool._get_headers(tool_context, credential)
380+
381+
@pytest.mark.asyncio
382+
async def test_get_headers_api_key_without_credentials_manager_raises_error(
383+
self,
384+
):
385+
"""Test that API Key without credentials manager raises ValueError."""
386+
tool = MCPTool(
387+
mcp_tool=self.mock_mcp_tool,
388+
mcp_session_manager=self.mock_session_manager,
389+
)
281390

391+
# Manually set credentials manager to None to simulate error condition
392+
tool._credentials_manager = None
393+
394+
credential = AuthCredential(
395+
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
396+
)
282397
tool_context = Mock(spec=ToolContext)
283-
headers = await tool._get_headers(tool_context, credential)
284398

285-
assert headers == {"X-API-Key": "my_api_key"}
399+
with pytest.raises(
400+
ValueError,
401+
match="Cannot find corresponding auth scheme for API key credential",
402+
):
403+
await tool._get_headers(tool_context, credential)
286404

287405
@pytest.mark.asyncio
288406
async def test_get_headers_no_credential(self):
@@ -318,6 +436,48 @@ async def test_get_headers_service_account(self):
318436
# Should return None as service account credentials are not supported for direct header generation
319437
assert headers is None
320438

439+
@pytest.mark.asyncio
440+
async def test_run_async_impl_with_api_key_header_auth(self):
441+
"""Test running tool with API key header authentication end-to-end."""
442+
from fastapi.openapi.models import APIKey
443+
from fastapi.openapi.models import APIKeyIn
444+
from google.adk.auth.auth_schemes import AuthSchemeType
445+
446+
# Create auth scheme for header-based API key
447+
auth_scheme = APIKey(**{
448+
"type": AuthSchemeType.apiKey,
449+
"in": APIKeyIn.header,
450+
"name": "X-Service-API-Key",
451+
})
452+
auth_credential = AuthCredential(
453+
auth_type=AuthCredentialTypes.API_KEY, api_key="test_service_key"
454+
)
455+
456+
tool = MCPTool(
457+
mcp_tool=self.mock_mcp_tool,
458+
mcp_session_manager=self.mock_session_manager,
459+
auth_scheme=auth_scheme,
460+
auth_credential=auth_credential,
461+
)
462+
463+
# Mock the session response
464+
expected_response = {"result": "authenticated_success"}
465+
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
466+
467+
tool_context = Mock(spec=ToolContext)
468+
args = {"param1": "test_value"}
469+
470+
result = await tool._run_async_impl(
471+
args=args, tool_context=tool_context, credential=auth_credential
472+
)
473+
474+
assert result == expected_response
475+
# Check that headers were passed correctly with custom API key header
476+
self.mock_session_manager.create_session.assert_called_once()
477+
call_args = self.mock_session_manager.create_session.call_args
478+
headers = call_args[1]["headers"]
479+
assert headers == {"X-Service-API-Key": "test_service_key"}
480+
321481
@pytest.mark.asyncio
322482
async def test_run_async_impl_retry_decorator(self):
323483
"""Test that the retry decorator is applied correctly."""
@@ -350,6 +510,45 @@ async def test_get_headers_http_custom_scheme(self):
350510

351511
assert headers == {"Authorization": "custom custom_token"}
352512

513+
@pytest.mark.asyncio
514+
async def test_get_headers_api_key_error_logging(self):
515+
"""Test that API key errors are logged correctly."""
516+
from fastapi.openapi.models import APIKey
517+
from fastapi.openapi.models import APIKeyIn
518+
from google.adk.auth.auth_schemes import AuthSchemeType
519+
520+
# Create auth scheme for query-based API key (not supported)
521+
auth_scheme = APIKey(**{
522+
"type": AuthSchemeType.apiKey,
523+
"in": APIKeyIn.query,
524+
"name": "api_key",
525+
})
526+
auth_credential = AuthCredential(
527+
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
528+
)
529+
530+
tool = MCPTool(
531+
mcp_tool=self.mock_mcp_tool,
532+
mcp_session_manager=self.mock_session_manager,
533+
auth_scheme=auth_scheme,
534+
auth_credential=auth_credential,
535+
)
536+
537+
tool_context = Mock(spec=ToolContext)
538+
539+
# Test with logging
540+
with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger:
541+
with pytest.raises(ValueError):
542+
await tool._get_headers(tool_context, auth_credential)
543+
544+
# Verify error was logged
545+
mock_logger.error.assert_called_once()
546+
logged_message = mock_logger.error.call_args[0][0]
547+
assert (
548+
"MCPTool only supports header-based API key authentication"
549+
in logged_message
550+
)
551+
353552
def test_init_validation(self):
354553
"""Test that initialization validates required parameters."""
355554
# This test ensures that the MCPTool properly handles its dependencies

0 commit comments

Comments
 (0)