Skip to content

Commit abc89d2

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Add implementation of VertexAiMemoryBankService and support in FastAPI endpoint
PiperOrigin-RevId: 775327151
1 parent 00cc8cd commit abc89d2

File tree

5 files changed

+321
-2
lines changed

5 files changed

+321
-2
lines changed

src/google/adk/cli/cli_tools_click.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,8 @@ def decorator(func):
489489
type=str,
490490
help=(
491491
"""Optional. The URI of the memory service.
492-
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service."""
492+
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service.
493+
- Use 'agentengine://<agent_engine_resource_id>' to connect to Vertex AI Memory Bank Service. e.g. agentengine://12345"""
493494
),
494495
default=None,
495496
)

src/google/adk/cli/fast_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
7272
from ..events.event import Event
7373
from ..memory.in_memory_memory_service import InMemoryMemoryService
74+
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
7475
from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
7576
from ..runners import Runner
7677
from ..sessions.database_session_service import DatabaseSessionService
@@ -282,6 +283,16 @@ async def internal_lifespan(app: FastAPI):
282283
memory_service = VertexAiRagMemoryService(
283284
rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}'
284285
)
286+
elif memory_service_uri.startswith("agentengine://"):
287+
agent_engine_id = memory_service_uri.split("://")[1]
288+
if not agent_engine_id:
289+
raise click.ClickException("Agent engine id can not be empty.")
290+
envs.load_dotenv_for_agent("", agents_dir)
291+
memory_service = VertexAiMemoryBankService(
292+
project=os.environ["GOOGLE_CLOUD_PROJECT"],
293+
location=os.environ["GOOGLE_CLOUD_LOCATION"],
294+
agent_engine_id=agent_engine_id,
295+
)
285296
else:
286297
raise click.ClickException(
287298
"Unsupported memory service URI: %s" % memory_service_uri

src/google/adk/memory/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515

1616
from .base_memory_service import BaseMemoryService
1717
from .in_memory_memory_service import InMemoryMemoryService
18+
from .vertex_ai_memory_bank_service import VertexAiMemoryBankService
1819

1920
logger = logging.getLogger('google_adk.' + __name__)
2021

2122
__all__ = [
2223
'BaseMemoryService',
2324
'InMemoryMemoryService',
25+
'VertexAiMemoryBankService',
2426
]
2527

2628
try:
@@ -29,7 +31,7 @@
2931
__all__.append('VertexAiRagMemoryService')
3032
except ImportError:
3133
logger.debug(
32-
'The Vertex sdk is not installed. If you want to use the'
34+
'The Vertex SDK is not installed. If you want to use the'
3335
' VertexAiRagMemoryService please install it. If not, you can ignore this'
3436
' warning.'
3537
)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import json
18+
import logging
19+
from typing import Optional
20+
from typing import TYPE_CHECKING
21+
22+
from typing_extensions import override
23+
24+
from google import genai
25+
26+
from .base_memory_service import BaseMemoryService
27+
from .base_memory_service import SearchMemoryResponse
28+
from .memory_entry import MemoryEntry
29+
30+
if TYPE_CHECKING:
31+
from ..sessions.session import Session
32+
33+
logger = logging.getLogger('google_adk.' + __name__)
34+
35+
36+
class VertexAiMemoryBankService(BaseMemoryService):
37+
"""Implementation of the BaseMemoryService using Vertex AI Memory Bank."""
38+
39+
def __init__(
40+
self,
41+
project: Optional[str] = None,
42+
location: Optional[str] = None,
43+
agent_engine_id: Optional[str] = None,
44+
):
45+
"""Initializes a VertexAiMemoryBankService.
46+
47+
Args:
48+
project: The project ID of the Memory Bank to use.
49+
location: The location of the Memory Bank to use.
50+
agent_engine_id: The ID of the agent engine to use for the Memory Bank.
51+
e.g. '456' in
52+
'projects/my-project/locations/us-central1/reasoningEngines/456'.
53+
"""
54+
self._project = project
55+
self._location = location
56+
self._agent_engine_id = agent_engine_id
57+
58+
@override
59+
async def add_session_to_memory(self, session: Session):
60+
api_client = self._get_api_client()
61+
62+
if not self._agent_engine_id:
63+
raise ValueError('Agent Engine ID is required for Memory Bank.')
64+
65+
events = []
66+
for event in session.events:
67+
if event.content and event.content.parts:
68+
events.append({
69+
'content': event.content.model_dump(exclude_none=True, mode='json')
70+
})
71+
request_dict = {
72+
'direct_contents_source': {
73+
'events': events,
74+
},
75+
'scope': {
76+
'app_name': session.app_name,
77+
'user_id': session.user_id,
78+
},
79+
}
80+
81+
api_response = await api_client.async_request(
82+
http_method='POST',
83+
path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
84+
request_dict=request_dict,
85+
)
86+
logger.info(f'Generate memory response: {api_response}')
87+
88+
@override
89+
async def search_memory(self, *, app_name: str, user_id: str, query: str):
90+
api_client = self._get_api_client()
91+
92+
api_response = await api_client.async_request(
93+
http_method='POST',
94+
path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve',
95+
request_dict={
96+
'scope': {
97+
'app_name': app_name,
98+
'user_id': user_id,
99+
},
100+
'similarity_search_params': {
101+
'search_query': query,
102+
},
103+
},
104+
)
105+
api_response = _convert_api_response(api_response)
106+
logger.info(f'Search memory response: {api_response}')
107+
108+
if not api_response or not api_response.get('retrievedMemories', None):
109+
return SearchMemoryResponse()
110+
111+
memory_events = []
112+
for memory in api_response.get('retrievedMemories', []):
113+
# TODO: add more complex error handling
114+
memory_events.append(
115+
MemoryEntry(
116+
author='user',
117+
content=genai.types.Content(
118+
parts=[
119+
genai.types.Part(text=memory.get('memory').get('fact'))
120+
],
121+
role='user',
122+
),
123+
timestamp=memory.get('updateTime'),
124+
)
125+
)
126+
return SearchMemoryResponse(memories=memory_events)
127+
128+
def _get_api_client(self):
129+
"""Instantiates an API client for the given project and location.
130+
131+
It needs to be instantiated inside each request so that the event loop
132+
management can be properly propagated.
133+
134+
Returns:
135+
An API client for the given project and location.
136+
"""
137+
client = genai.Client(
138+
vertexai=True, project=self._project, location=self._location
139+
)
140+
return client._api_client
141+
142+
143+
def _convert_api_response(api_response):
144+
"""Converts the API response to a JSON object based on the type."""
145+
if hasattr(api_response, 'body'):
146+
return json.loads(api_response.body)
147+
return api_response
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import re
16+
from typing import Any
17+
from unittest import mock
18+
19+
from google.adk.events import Event
20+
from google.adk.memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
21+
from google.adk.sessions import Session
22+
from google.genai import types
23+
import pytest
24+
25+
MOCK_APP_NAME = 'test-app'
26+
MOCK_USER_ID = 'test-user'
27+
28+
MOCK_SESSION = Session(
29+
app_name=MOCK_APP_NAME,
30+
user_id=MOCK_USER_ID,
31+
id='333',
32+
last_update_time=22333,
33+
events=[
34+
Event(
35+
id='444',
36+
invocation_id='123',
37+
author='user',
38+
timestamp=12345,
39+
content=types.Content(parts=[types.Part(text='test_content')]),
40+
),
41+
# Empty event, should be ignored
42+
Event(
43+
id='555',
44+
invocation_id='456',
45+
author='user',
46+
timestamp=12345,
47+
),
48+
],
49+
)
50+
51+
52+
RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$'
53+
GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$'
54+
55+
56+
class MockApiClient:
57+
"""Mocks the API Client."""
58+
59+
def __init__(self) -> None:
60+
"""Initializes MockClient."""
61+
self.async_request = mock.AsyncMock()
62+
self.async_request.side_effect = self._mock_async_request
63+
64+
async def _mock_async_request(
65+
self, http_method: str, path: str, request_dict: dict[str, Any]
66+
):
67+
"""Mocks the API Client request method."""
68+
if http_method == 'POST':
69+
if re.match(GENERATE_MEMORIES_REGEX, path):
70+
return {}
71+
elif re.match(RETRIEVE_MEMORIES_REGEX, path):
72+
if (
73+
request_dict.get('scope', None)
74+
and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME
75+
):
76+
return {
77+
'retrievedMemories': [
78+
{
79+
'memory': {
80+
'fact': 'test_content',
81+
},
82+
'updateTime': '2024-12-12T12:12:12.123456Z',
83+
},
84+
],
85+
}
86+
else:
87+
return {'retrievedMemories': []}
88+
else:
89+
raise ValueError(f'Unsupported path: {path}')
90+
else:
91+
raise ValueError(f'Unsupported http method: {http_method}')
92+
93+
94+
def mock_vertex_ai_memory_bank_service():
95+
"""Creates a mock Vertex AI Memory Bank service for testing."""
96+
return VertexAiMemoryBankService(
97+
project='test-project',
98+
location='test-location',
99+
agent_engine_id='123',
100+
)
101+
102+
103+
@pytest.fixture
104+
def mock_get_api_client():
105+
api_client = MockApiClient()
106+
with mock.patch(
107+
'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client',
108+
return_value=api_client,
109+
):
110+
yield api_client
111+
112+
113+
@pytest.mark.asyncio
114+
@pytest.mark.usefixtures('mock_get_api_client')
115+
async def test_add_session_to_memory(mock_get_api_client):
116+
memory_service = mock_vertex_ai_memory_bank_service()
117+
await memory_service.add_session_to_memory(MOCK_SESSION)
118+
119+
mock_get_api_client.async_request.assert_awaited_once_with(
120+
http_method='POST',
121+
path='reasoningEngines/123/memories:generate',
122+
request_dict={
123+
'direct_contents_source': {
124+
'events': [
125+
{
126+
'content': {
127+
'parts': [
128+
{'text': 'test_content'},
129+
],
130+
},
131+
},
132+
],
133+
},
134+
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
135+
},
136+
)
137+
138+
139+
@pytest.mark.asyncio
140+
@pytest.mark.usefixtures('mock_get_api_client')
141+
async def test_search_memory(mock_get_api_client):
142+
memory_service = mock_vertex_ai_memory_bank_service()
143+
144+
result = await memory_service.search_memory(
145+
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
146+
)
147+
148+
mock_get_api_client.async_request.assert_awaited_once_with(
149+
http_method='POST',
150+
path='reasoningEngines/123/memories:retrieve',
151+
request_dict={
152+
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
153+
'similarity_search_params': {'search_query': 'query'},
154+
},
155+
)
156+
157+
assert len(result.memories) == 1
158+
assert result.memories[0].content.parts[0].text == 'test_content'

0 commit comments

Comments
 (0)