Skip to content

Commit 3b1f2ae

Browse files
selcukguncopybara-github
authored andcommitted
fix: Fix broken agent graphs
Fixes #973, #1170 PiperOrigin-RevId: 767921051
1 parent 6ed6351 commit 3b1f2ae

File tree

1 file changed

+37
-35
lines changed

1 file changed

+37
-35
lines changed

src/google/adk/cli/agent_graph.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,11 @@ def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
6464
if isinstance(tool_or_agent, BaseAgent):
6565
# Added Workflow Agent checks for different agent types
6666
if isinstance(tool_or_agent, SequentialAgent):
67-
return tool_or_agent.name + f' (Sequential Agent)'
67+
return tool_or_agent.name + ' (Sequential Agent)'
6868
elif isinstance(tool_or_agent, LoopAgent):
69-
return tool_or_agent.name + f' (Loop Agent)'
69+
return tool_or_agent.name + ' (Loop Agent)'
7070
elif isinstance(tool_or_agent, ParallelAgent):
71-
return tool_or_agent.name + f' (Parallel Agent)'
71+
return tool_or_agent.name + ' (Parallel Agent)'
7272
else:
7373
return tool_or_agent.name
7474
elif isinstance(tool_or_agent, BaseTool):
@@ -144,49 +144,53 @@ def should_build_agent_cluster(tool_or_agent: Union[BaseAgent, BaseTool]):
144144
)
145145
return False
146146

147-
def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str):
148-
if isinstance(agent, LoopAgent) and parent_agent:
147+
async def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str):
148+
if isinstance(agent, LoopAgent):
149149
# Draw the edge from the parent agent to the first sub-agent
150-
draw_edge(parent_agent.name, agent.sub_agents[0].name)
150+
if parent_agent:
151+
draw_edge(parent_agent.name, agent.sub_agents[0].name)
151152
length = len(agent.sub_agents)
152-
currLength = 0
153+
curr_length = 0
153154
# Draw the edges between the sub-agents
154155
for sub_agent_int_sequential in agent.sub_agents:
155-
build_graph(child, sub_agent_int_sequential, highlight_pairs)
156+
await build_graph(child, sub_agent_int_sequential, highlight_pairs)
156157
# Draw the edge between the current sub-agent and the next one
157158
# If it's the last sub-agent, draw an edge to the first one to indicating a loop
158159
draw_edge(
159-
agent.sub_agents[currLength].name,
160+
agent.sub_agents[curr_length].name,
160161
agent.sub_agents[
161-
0 if currLength == length - 1 else currLength + 1
162+
0 if curr_length == length - 1 else curr_length + 1
162163
].name,
163164
)
164-
currLength += 1
165-
elif isinstance(agent, SequentialAgent) and parent_agent:
165+
curr_length += 1
166+
elif isinstance(agent, SequentialAgent):
166167
# Draw the edge from the parent agent to the first sub-agent
167-
draw_edge(parent_agent.name, agent.sub_agents[0].name)
168+
if parent_agent:
169+
draw_edge(parent_agent.name, agent.sub_agents[0].name)
168170
length = len(agent.sub_agents)
169-
currLength = 0
171+
curr_length = 0
170172

171173
# Draw the edges between the sub-agents
172174
for sub_agent_int_sequential in agent.sub_agents:
173-
build_graph(child, sub_agent_int_sequential, highlight_pairs)
175+
await build_graph(child, sub_agent_int_sequential, highlight_pairs)
174176
# Draw the edge between the current sub-agent and the next one
175177
# If it's the last sub-agent, don't draw an edge to avoid a loop
176-
draw_edge(
177-
agent.sub_agents[currLength].name,
178-
agent.sub_agents[currLength + 1].name,
179-
) if currLength != length - 1 else None
180-
currLength += 1
178+
if curr_length != length - 1:
179+
draw_edge(
180+
agent.sub_agents[curr_length].name,
181+
agent.sub_agents[curr_length + 1].name,
182+
)
183+
curr_length += 1
181184

182-
elif isinstance(agent, ParallelAgent) and parent_agent:
185+
elif isinstance(agent, ParallelAgent):
183186
# Draw the edge from the parent agent to every sub-agent
184187
for sub_agent in agent.sub_agents:
185-
build_graph(child, sub_agent, highlight_pairs)
186-
draw_edge(parent_agent.name, sub_agent.name)
188+
await build_graph(child, sub_agent, highlight_pairs)
189+
if parent_agent:
190+
draw_edge(parent_agent.name, sub_agent.name)
187191
else:
188192
for sub_agent in agent.sub_agents:
189-
build_graph(child, sub_agent, highlight_pairs)
193+
await build_graph(child, sub_agent, highlight_pairs)
190194
draw_edge(agent.name, sub_agent.name)
191195

192196
child.attr(
@@ -196,21 +200,20 @@ def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str):
196200
fontcolor=light_gray,
197201
)
198202

199-
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
203+
async def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
200204
name = get_node_name(tool_or_agent)
201205
shape = get_node_shape(tool_or_agent)
202206
caption = get_node_caption(tool_or_agent)
203-
asCluster = should_build_agent_cluster(tool_or_agent)
204-
child = None
207+
as_cluster = should_build_agent_cluster(tool_or_agent)
205208
if highlight_pairs:
206209
for highlight_tuple in highlight_pairs:
207210
if name in highlight_tuple:
208211
# if in highlight, draw highlight node
209-
if asCluster:
212+
if as_cluster:
210213
cluster = graphviz.Digraph(
211214
name='cluster_' + name
212215
) # adding "cluster_" to the name makes the graph render as a cluster subgraph
213-
build_cluster(cluster, agent, name)
216+
await build_cluster(cluster, agent, name)
214217
graph.subgraph(cluster)
215218
else:
216219
graph.node(
@@ -224,12 +227,12 @@ def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
224227
)
225228
return
226229
# if not in highlight, draw non-highlight node
227-
if asCluster:
230+
if as_cluster:
228231

229232
cluster = graphviz.Digraph(
230233
name='cluster_' + name
231234
) # adding "cluster_" to the name makes the graph render as a cluster subgraph
232-
build_cluster(cluster, agent, name)
235+
await build_cluster(cluster, agent, name)
233236
graph.subgraph(cluster)
234237

235238
else:
@@ -264,10 +267,9 @@ def draw_edge(from_name, to_name):
264267
else:
265268
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
266269

267-
draw_node(agent)
270+
await draw_node(agent)
268271
for sub_agent in agent.sub_agents:
269-
270-
build_graph(graph, sub_agent, highlight_pairs, agent)
272+
await build_graph(graph, sub_agent, highlight_pairs, agent)
271273
if not should_build_agent_cluster(
272274
sub_agent
273275
) and not should_build_agent_cluster(
@@ -276,7 +278,7 @@ def draw_edge(from_name, to_name):
276278
draw_edge(agent.name, sub_agent.name)
277279
if isinstance(agent, LlmAgent):
278280
for tool in await agent.canonical_tools():
279-
draw_node(tool)
281+
await draw_node(tool)
280282
draw_edge(agent.name, get_node_name(tool))
281283

282284

0 commit comments

Comments
 (0)