Skip to content

Commit be0786e

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: rename agent_dir to agents_dir and rename app_id to app_name in fast_api.py to make it consistent among every endpoints
PiperOrigin-RevId: 763483339
1 parent 6b89ceb commit be0786e

File tree

6 files changed

+32
-32
lines changed

6 files changed

+32
-32
lines changed

src/google/adk/cli/cli_tools_click.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ async def _collect_eval_results() -> list[EvalCaseResult]:
355355

356356
# Write eval set results.
357357
local_eval_set_results_manager = LocalEvalSetResultsManager(
358-
agent_dir=os.path.dirname(agent_module_file_path)
358+
agents_dir=os.path.dirname(agent_module_file_path)
359359
)
360360
eval_set_id_to_eval_results = collections.defaultdict(list)
361361
for eval_case_result in eval_results:
@@ -500,7 +500,7 @@ async def _lifespan(app: FastAPI):
500500
)
501501

502502
app = get_fast_api_app(
503-
agent_dir=agents_dir,
503+
agents_dir=agents_dir,
504504
session_db_url=session_db_url,
505505
allow_origins=allow_origins,
506506
web=True,
@@ -601,7 +601,7 @@ def cli_api_server(
601601

602602
config = uvicorn.Config(
603603
get_fast_api_app(
604-
agent_dir=agents_dir,
604+
agents_dir=agents_dir,
605605
session_db_url=session_db_url,
606606
allow_origins=allow_origins,
607607
web=False,

src/google/adk/cli/fast_api.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class GetEventGraphResult(common.BaseModel):
191191

192192
def get_fast_api_app(
193193
*,
194-
agent_dir: str,
194+
agents_dir: str,
195195
session_db_url: str = "",
196196
allow_origins: Optional[list[str]] = None,
197197
web: bool,
@@ -210,7 +210,7 @@ def get_fast_api_app(
210210
memory_exporter = InMemoryExporter(session_trace_dict)
211211
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
212212
if trace_to_cloud:
213-
envs.load_dotenv_for_agent("", agent_dir)
213+
envs.load_dotenv_for_agent("", agents_dir)
214214
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
215215
processor = export.BatchSpanProcessor(
216216
CloudTraceSpanExporter(project_id=project_id)
@@ -249,8 +249,8 @@ async def internal_lifespan(app: FastAPI):
249249
allow_headers=["*"],
250250
)
251251

252-
if agent_dir not in sys.path:
253-
sys.path.append(agent_dir)
252+
if agents_dir not in sys.path:
253+
sys.path.append(agents_dir)
254254

255255
runner_dict = {}
256256
root_agent_dict = {}
@@ -259,8 +259,8 @@ async def internal_lifespan(app: FastAPI):
259259
artifact_service = InMemoryArtifactService()
260260
memory_service = InMemoryMemoryService()
261261

262-
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
263-
eval_set_results_manager = LocalEvalSetResultsManager(agent_dir=agent_dir)
262+
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
263+
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
264264

265265
# Build the Session service
266266
agent_engine_id = ""
@@ -270,7 +270,7 @@ async def internal_lifespan(app: FastAPI):
270270
agent_engine_id = session_db_url.split("://")[1]
271271
if not agent_engine_id:
272272
raise click.ClickException("Agent engine id can not be empty.")
273-
envs.load_dotenv_for_agent("", agent_dir)
273+
envs.load_dotenv_for_agent("", agents_dir)
274274
session_service = VertexAiSessionService(
275275
os.environ["GOOGLE_CLOUD_PROJECT"],
276276
os.environ["GOOGLE_CLOUD_LOCATION"],
@@ -282,7 +282,7 @@ async def internal_lifespan(app: FastAPI):
282282

283283
@app.get("/list-apps")
284284
def list_apps() -> list[str]:
285-
base_path = Path.cwd() / agent_dir
285+
base_path = Path.cwd() / agents_dir
286286
if not base_path.exists():
287287
raise HTTPException(status_code=404, detail="Path not found")
288288
if not base_path.is_dir():
@@ -398,9 +398,9 @@ async def create_session(
398398
app_name=app_name, user_id=user_id, state=state
399399
)
400400

401-
def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
401+
def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str:
402402
return os.path.join(
403-
agent_dir,
403+
agents_dir,
404404
app_name,
405405
eval_set_id + _EVAL_SET_FILE_EXTENSION,
406406
)
@@ -490,7 +490,7 @@ async def run_eval(
490490

491491
# Create a mapping from eval set file to all the evals that needed to be
492492
# run.
493-
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
493+
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
494494

495495
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
496496

@@ -663,9 +663,9 @@ async def delete_artifact(
663663
@app.post("/run", response_model_exclude_none=True)
664664
async def agent_run(req: AgentRunRequest) -> list[Event]:
665665
# Connect to managed session if agent_engine_id is set.
666-
app_id = agent_engine_id if agent_engine_id else req.app_name
666+
app_name = agent_engine_id if agent_engine_id else req.app_name
667667
session = await session_service.get_session(
668-
app_name=app_id, user_id=req.user_id, session_id=req.session_id
668+
app_name=app_name, user_id=req.user_id, session_id=req.session_id
669669
)
670670
if not session:
671671
raise HTTPException(status_code=404, detail="Session not found")
@@ -684,10 +684,10 @@ async def agent_run(req: AgentRunRequest) -> list[Event]:
684684
@app.post("/run_sse")
685685
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
686686
# Connect to managed session if agent_engine_id is set.
687-
app_id = agent_engine_id if agent_engine_id else req.app_name
687+
app_name = agent_engine_id if agent_engine_id else req.app_name
688688
# SSE endpoint
689689
session = await session_service.get_session(
690-
app_name=app_id, user_id=req.user_id, session_id=req.session_id
690+
app_name=app_name, user_id=req.user_id, session_id=req.session_id
691691
)
692692
if not session:
693693
raise HTTPException(status_code=404, detail="Session not found")
@@ -726,9 +726,9 @@ async def get_event_graph(
726726
app_name: str, user_id: str, session_id: str, event_id: str
727727
):
728728
# Connect to managed session if agent_engine_id is set.
729-
app_id = agent_engine_id if agent_engine_id else app_name
729+
app_name = agent_engine_id if agent_engine_id else app_name
730730
session = await session_service.get_session(
731-
app_name=app_id, user_id=user_id, session_id=session_id
731+
app_name=app_name, user_id=user_id, session_id=session_id
732732
)
733733
session_events = session.events if session else []
734734
event = next((x for x in session_events if x.id == event_id), None)
@@ -783,9 +783,9 @@ async def agent_live_run(
783783
await websocket.accept()
784784

785785
# Connect to managed session if agent_engine_id is set.
786-
app_id = agent_engine_id if agent_engine_id else app_name
786+
app_name = agent_engine_id if agent_engine_id else app_name
787787
session = await session_service.get_session(
788-
app_name=app_id, user_id=user_id, session_id=session_id
788+
app_name=app_name, user_id=user_id, session_id=session_id
789789
)
790790
if not session:
791791
# Accept first so that the client is aware of connection establishment,
@@ -855,7 +855,7 @@ async def _get_root_agent_async(app_name: str) -> Agent:
855855

856856
async def _get_runner_async(app_name: str) -> Runner:
857857
"""Returns the runner for the given app."""
858-
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
858+
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
859859
if app_name in runner_dict:
860860
return runner_dict[app_name]
861861
root_agent = await _get_root_agent_async(app_name)

src/google/adk/evaluation/local_eval_set_results_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def _sanitize_eval_set_result_name(eval_set_result_name: str) -> str:
3636
class LocalEvalSetResultsManager(EvalSetResultsManager):
3737
"""An EvalSetResult manager that stores eval set results locally on disk."""
3838

39-
def __init__(self, agent_dir: str):
40-
self._agent_dir = agent_dir
39+
def __init__(self, agents_dir: str):
40+
self._agents_dir = agents_dir
4141

4242
@override
4343
def save_eval_set_result(
@@ -108,4 +108,4 @@ def list_eval_set_results(self, app_name: str) -> list[str]:
108108
return eval_result_files
109109

110110
def _get_eval_history_dir(self, app_name: str) -> str:
111-
return os.path.join(self._agent_dir, app_name, _ADK_EVAL_HISTORY_DIR)
111+
return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR)

src/google/adk/evaluation/local_eval_sets_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def load_eval_set_from_file(
182182
class LocalEvalSetsManager(EvalSetsManager):
183183
"""An EvalSets manager that stores eval sets locally on disk."""
184184

185-
def __init__(self, agent_dir: str):
186-
self._agent_dir = agent_dir
185+
def __init__(self, agents_dir: str):
186+
self._agents_dir = agents_dir
187187

188188
@override
189189
def get_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
@@ -216,7 +216,7 @@ def create_eval_set(self, app_name: str, eval_set_id: str):
216216
@override
217217
def list_eval_sets(self, app_name: str) -> list[str]:
218218
"""Returns a list of EvalSets that belong to the given app_name."""
219-
eval_set_file_path = os.path.join(self._agent_dir, app_name)
219+
eval_set_file_path = os.path.join(self._agents_dir, app_name)
220220
eval_sets = []
221221
for file in os.listdir(eval_set_file_path):
222222
if file.endswith(_EVAL_SET_FILE_EXTENSION):
@@ -247,7 +247,7 @@ def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: EvalCase):
247247

248248
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
249249
return os.path.join(
250-
self._agent_dir,
250+
self._agents_dir,
251251
app_name,
252252
eval_set_id + _EVAL_SET_FILE_EXTENSION,
253253
)

tests/unittests/fast_api/test_fast_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def test_app(mock_session_service, mock_artifact_service, mock_memory_service):
308308
):
309309
# Get the FastAPI app, but don't actually run it
310310
app = get_fast_api_app(
311-
agent_dir=".", web=True, session_db_url="", allow_origins=["*"]
311+
agents_dir=".", web=True, session_db_url="", allow_origins=["*"]
312312
)
313313

314314
# Create a TestClient that doesn't start a real server

tests/unittests/testing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def run_async(self, new_message: types.ContentUnion) -> list[Event]:
202202
session_id=self.session.id,
203203
new_message=get_user_content(new_message),
204204
):
205-
events.append(event)
205+
events.append(event)
206206
return events
207207

208208
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:

0 commit comments

Comments
 (0)