Skip to content

Commit 92c3749

Browse files
seanzhougoogle authored and copybara-github committed
refactor: simplify the toolset cleanup code and extract the common cleanup code into utils, so that it can be used by the CLI or by client code that calls runners directly
PiperOrigin-RevId: 762463028
1 parent b9b2c3f commit 92c3749

File tree

6 files changed

+239
-406
lines changed

6 files changed

+239
-406
lines changed

src/google/adk/cli/fast_api.py

Lines changed: 3 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,9 @@
1616
import asyncio
1717
from contextlib import asynccontextmanager
1818
import importlib
19-
import inspect
20-
import json
2119
import logging
2220
import os
2321
from pathlib import Path
24-
import signal
2522
import sys
2623
import time
2724
import traceback
@@ -55,11 +52,9 @@
5552
from typing_extensions import override
5653

5754
from ..agents import RunConfig
58-
from ..agents.base_agent import BaseAgent
5955
from ..agents.live_request_queue import LiveRequest
6056
from ..agents.live_request_queue import LiveRequestQueue
6157
from ..agents.llm_agent import Agent
62-
from ..agents.llm_agent import LlmAgent
6358
from ..agents.run_config import StreamingMode
6459
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
6560
from ..evaluation.eval_case import EvalCase
@@ -75,12 +70,12 @@
7570
from ..sessions.vertex_ai_session_service import VertexAiSessionService
7671
from ..tools.base_toolset import BaseToolset
7772
from .cli_eval import EVAL_SESSION_ID_PREFIX
78-
from .cli_eval import EvalCaseResult
7973
from .cli_eval import EvalMetric
8074
from .cli_eval import EvalMetricResult
8175
from .cli_eval import EvalMetricResultPerInvocation
8276
from .cli_eval import EvalSetResult
8377
from .cli_eval import EvalStatus
78+
from .utils import cleanup
8479
from .utils import common
8580
from .utils import create_empty_state
8681
from .utils import envs
@@ -230,27 +225,8 @@ def get_fast_api_app(
230225

231226
trace.set_tracer_provider(provider)
232227

233-
toolsets_to_close: set[BaseToolset] = set()
234-
235228
@asynccontextmanager
236229
async def internal_lifespan(app: FastAPI):
237-
# Set up signal handlers for graceful shutdown
238-
original_sigterm = signal.getsignal(signal.SIGTERM)
239-
original_sigint = signal.getsignal(signal.SIGINT)
240-
241-
def cleanup_handler(sig, frame):
242-
# Log the signal
243-
logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
244-
# Do synchronous cleanup if needed
245-
# Then call original handler if it exists
246-
if sig == signal.SIGTERM and callable(original_sigterm):
247-
original_sigterm(sig, frame)
248-
elif sig == signal.SIGINT and callable(original_sigint):
249-
original_sigint(sig, frame)
250-
251-
# Install cleanup handlers
252-
signal.signal(signal.SIGTERM, cleanup_handler)
253-
signal.signal(signal.SIGINT, cleanup_handler)
254230

255231
try:
256232
if lifespan:
@@ -259,46 +235,8 @@ def cleanup_handler(sig, frame):
259235
else:
260236
yield
261237
finally:
262-
# During shutdown, properly clean up all toolsets
263-
logger.info(
264-
"Server shutdown initiated, cleaning up %s toolsets",
265-
len(toolsets_to_close),
266-
)
267-
268-
# Create tasks for all toolset closures to run concurrently
269-
cleanup_tasks = []
270-
for toolset in toolsets_to_close:
271-
task = asyncio.create_task(close_toolset_safely(toolset))
272-
cleanup_tasks.append(task)
273-
274-
if cleanup_tasks:
275-
# Wait for all cleanup tasks with timeout
276-
done, pending = await asyncio.wait(
277-
cleanup_tasks,
278-
timeout=10.0, # 10 second timeout for cleanup
279-
return_when=asyncio.ALL_COMPLETED,
280-
)
281-
282-
# If any tasks are still pending, log it
283-
if pending:
284-
logger.warning(
285-
f"{len(pending)} toolset cleanup tasks didn't complete in time"
286-
)
287-
for task in pending:
288-
task.cancel()
289-
290-
# Restore original signal handlers
291-
signal.signal(signal.SIGTERM, original_sigterm)
292-
signal.signal(signal.SIGINT, original_sigint)
293-
294-
async def close_toolset_safely(toolset):
295-
"""Safely close a toolset with error handling."""
296-
try:
297-
logger.info(f"Closing toolset: {type(toolset).__name__}")
298-
await toolset.close()
299-
logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
300-
except Exception as e:
301-
logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
238+
# Create tasks for all runner closures to run concurrently
239+
await cleanup.close_runners(list(runner_dict.values()))
302240

303241
# Run the FastAPI server.
304242
app = FastAPI(lifespan=internal_lifespan)
@@ -903,16 +841,6 @@ async def process_messages():
903841
for task in pending:
904842
task.cancel()
905843

906-
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
907-
toolsets = set()
908-
if isinstance(agent, LlmAgent):
909-
for tool_union in agent.tools:
910-
if isinstance(tool_union, BaseToolset):
911-
toolsets.add(tool_union)
912-
for sub_agent in agent.sub_agents:
913-
toolsets.update(_get_all_toolsets(sub_agent))
914-
return toolsets
915-
916844
async def _get_root_agent_async(app_name: str) -> Agent:
917845
"""Returns the root agent for the given app."""
918846
if app_name in root_agent_dict:
@@ -924,7 +852,6 @@ async def _get_root_agent_async(app_name: str) -> Agent:
924852
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
925853

926854
root_agent_dict[app_name] = root_agent
927-
toolsets_to_close.update(_get_all_toolsets(root_agent))
928855
return root_agent
929856

930857
async def _get_runner_async(app_name: str) -> Runner:

src/google/adk/cli/utils/cleanup.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import logging
17+
from typing import List
18+
19+
from ...runners import Runner
20+
21+
logger = logging.getLogger("google_adk." + __name__)
22+
23+
24+
async def close_runners(runners: List[Runner]) -> None:
25+
cleanup_tasks = [asyncio.create_task(runner.close()) for runner in runners]
26+
if cleanup_tasks:
27+
# Wait for all cleanup tasks with timeout
28+
done, pending = await asyncio.wait(
29+
cleanup_tasks,
30+
timeout=30.0, # 30 second timeout for cleanup
31+
return_when=asyncio.ALL_COMPLETED,
32+
)
33+
34+
# If any tasks are still pending, log it
35+
if pending:
36+
logger.warning(
37+
"%s runner close tasks didn't complete in time", len(pending)
38+
)
39+
for task in pending:
40+
task.cancel()

src/google/adk/runners.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .sessions.in_memory_session_service import InMemorySessionService
4343
from .sessions.session import Session
4444
from .telemetry import tracer
45+
from .tools.base_toolset import BaseToolset
4546

4647
logger = logging.getLogger('google_adk.' + __name__)
4748

@@ -457,6 +458,37 @@ def _new_invocation_context_for_live(
457458
run_config=run_config,
458459
)
459460

461+
def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
  """Gathers every toolset attached to an agent and its descendants.

  Args:
    agent: The agent whose tree (itself plus all nested sub-agents) is
      searched.

  Returns:
    The set of ``BaseToolset`` instances found in the ``tools`` of any
    ``LlmAgent`` in the tree. Agents of other types contribute no toolsets
    themselves, but their sub-agents are still visited.
  """
  found: set[BaseToolset] = set()
  # Explicit worklist instead of recursion; visiting order does not matter
  # because the result is a set.
  to_visit = [agent]
  while to_visit:
    current = to_visit.pop()
    if isinstance(current, LlmAgent):
      found.update(
          tool for tool in current.tools if isinstance(tool, BaseToolset)
      )
    to_visit.extend(current.sub_agents)
  return found
470+
471+
async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]):
  """Closes the given toolsets one after another, logging the outcome.

  Each ``close()`` call is limited to 10 seconds. A timeout or any other
  exception from one toolset is logged and does not prevent the remaining
  toolsets from being closed.

  Args:
    toolsets_to_close: The toolsets to close. An empty set is a no-op.
  """
  if toolsets_to_close:
    # Toolsets are awaited sequentially rather than in spawned tasks; per the
    # original author's note this keeps cleanup in the same task context.
    for current in toolsets_to_close:
      toolset_name = type(current).__name__
      try:
        logger.info('Closing toolset: %s', toolset_name)
        # wait_for bounds how long a single close() may take.
        await asyncio.wait_for(current.close(), timeout=10.0)
        logger.info('Successfully closed toolset: %s', toolset_name)
      except asyncio.TimeoutError:
        logger.warning('Toolset %s cleanup timed out', toolset_name)
      except Exception as err:
        logger.error('Error closing toolset %s: %s', toolset_name, err)
487+
488+
async def close(self):
  """Closes the runner by closing every toolset reachable from its agent."""
  # Collect first, then clean up: toolsets come from self.agent's tree.
  toolsets = self._collect_toolset(self.agent)
  await self._cleanup_toolsets(toolsets)
491+
460492

461493
class InMemoryRunner(Runner):
462494
"""An in-memory Runner for testing and development.

0 commit comments

Comments
 (0)