8000 feat(utils): visualize agent workflows in Mermaid by d33bs · Pull Request #1441 · google/adk-python · GitHub
[go: up one dir, main page]

Skip to content

feat(utils): visualize agent workflows in Mermaid #1441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions src/google/adk/utils/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Utilities for visualizing google-adk agents.
"""

from __future__ import annotations

import itertools
from typing import Any
from typing import Tuple

from google.adk.agents import LoopAgent
from google.adk.agents import ParallelAgent
from google.adk.agents import SequentialAgent
import requests


def build_mermaid(root_agent: Any) -> Tuple[str, bytes]:
"""
Generates a Mermaid 'flowchart LR' diagram for a google-adk
agent tree and returns both the Mermaid source and a PNG
image rendered via the Kroki API.

Args:
root_agent (Any):
The root agent node of the google-adk agent tree.
This should be an instance
of SequentialAgent, LoopAgent, ParallelAgent,
or a compatible agent class with a
`name` attribute and an optional `sub_agents`
attribute.

Returns:
Tuple[str, bytes]:
A tuple containing:
- The Mermaid source code as a string.
- The PNG image bytes rendered from the Mermaid diagram.

Raises:
requests.RequestException: If the request to the Kroki API fails.

Example:
>>> mermaid_src, png_bytes = build_mermaid(my_agent_tree)
>>> print(mermaid_src)
>>> with open("diagram.png", "wb") as f:
... f.write(png_bytes)
"""
clusters, edges = [], []
first_of, last_of, nodes = {}, {}, {}

# Walk the agent tree
def walk(node):
nid = id(node)
nodes[nid] = node
name = node.name
subs = getattr(node, "sub_agents", []) or []
if subs:
first_of[nid], last_of[nid] = subs[0].name, subs[-1].name
# Create subgraph for non-root composite nodes
if node is not root_agent and isinstance(
node, (SequentialAgent, LoopAgent, ParallelAgent)
):
block = [f'subgraph {name}["{name}"]']
if isinstance(node, (SequentialAgent, LoopAgent)):
for a, b in itertools.pairwise(subs):
block.append(f" {a.name} --> {b.name}")
# loop-back even for single-child loops
if isinstance(node, LoopAgent):
if len(subs) == 1:
block.append(f" {subs[0].name} -.->|repeat| {subs[0].name}")
elif len(subs) > 1:
block.append(f" {subs[-1].name} -.->|repeat| {subs[0].name}")
elif isinstance(node, ParallelAgent):
for child in subs:
block.append(f' {child.name}["{child.name}"]')
block.append("end")
clusters.append("\n".join(block))
# Recurse
for child in subs:
walk(child)

walk(root_agent)

# Link root children
if isinstance(root_agent, SequentialAgent):
children = root_agent.sub_agents or []
# Kick-off
if children:
first = children[0]
if isinstance(first, ParallelAgent):
for c in first.sub_agents:
edges.append(f"{root_agent.name} -.-> {c.name}")
else:
edges.append(
f"{root_agent.name} -.-> {first_of.get(id(first), first.name)}"
)
# Chain
for prev, nxt in itertools.pairwise(children):
prev_exits = (
[c.name for c in prev.sub_agents]
if isinstance(prev, ParallelAgent)
else [last_of.get(id(prev), prev.name)]
)
nxt_entries = (
[c.name for c in nxt.sub_agents]
if isinstance(nxt, ParallelAgent)
else [first_of.get(id(nxt), nxt.name)]
)
arrow = "-.->" if isinstance(nxt, ParallelAgent) else "-->"
for src in prev_exits:
for dst in nxt_entries:
edges.append(f"{src} {arrow} {dst}")
else:
for c in getattr(root_agent, "sub_agents", []) or []:
edges.append(f"{root_agent.name} --> {c.name}")

# Assemble graph as mermaid code
mermaid_src = "\n".join(
["flowchart LR", f'{root_agent.name}["{root_agent.name}"]']
+ clusters
+ edges
)

# Render via Kroki
# note: kroki is a third party service which enables the rendering
# of mermaid diagrams without local npm installation of mermaid.
png = requests.post(
"https://kroki.io/mermaid/png",
data=mermaid_src.encode("utf-8"),
headers={"Content-Type": "text/plain"},
).content

return mermaid_src, png
118 changes: 118 additions & 0 deletions tests/unittests/utils/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Tests for src/google/adk/utils/visualization.py
"""

from __future__ import annotations

from google.adk.agents import Agent
from google.adk.agents import LoopAgent
from google.adk.agents import ParallelAgent
from google.adk.agents import SequentialAgent
from google.adk.utils.visualization import build_mermaid


def test_build_mermaid():
"""
Tests build_mermaid function.

We build an agent workflow and then pass the
root_agent to build_mermaid, which builds
a mermaid diagram.
"""

agent1 = Agent(
model="nonexistent-model",
name="agent1",
description="Example",
instruction="""
Example
""",
)

agent2 = Agent(
model="nonexistent-model",
name="agent2",
description="Example",
instruction="""
Example
""",
)

agent3 = Agent(
model="nonexistent-model",
name="agent3",
description="Example",
instruction=f"""
Example
""",
)

agent4 = Agent(
model="nonexistent-model",
name="agent4",
description="Example",
instruction=f"""
Example
""",
)

agent5 = Agent(
model="nonexistent-model",
name="agent5",
description="Example",
instruction=f"""
Example
""",
)

agent6 = Agent(
model="nonexistent-model",
name="agent6",
description="Example",
instruction=f"""
Example
""",
)

agent7 = Agent(
model="nonexistent-model",
name="agent7",
description="Example",
instruction=f"""
Example
""",
)

# example sequence
sequence_1 = SequentialAgent(
name="ExampleSequence",
sub_agents=[agent1, agent2],
)

# example loop
loop_1 = LoopAgent(
name="ExampleLoop",
sub_agents=[agent6, agent7],
max_iterations=10,
)

# example parallel
parallel_1 = ParallelAgent(
name="ExampleParallel",
sub_agents=[agent3, agent4, agent5],
)

# sequence for orchestrating everything together
root_agent = SequentialAgent(
name="root_agent",
sub_agents=[sequence_1, loop_1, parallel_1],
description="Example",
)

mermaid_src, png_display_bytes = build_mermaid(root_agent)

assert isinstance(mermaid_src, str)
assert mermaid_src.startswith("flowchart LR")

assert isinstance(png_display_bytes, bytes)
assert png_display_bytes.startswith(b"\x89PNG\r\n\x1a\n")
0