8000 feat: add ConditionalAgent for routing between two sub-agents by LaamiriOuail · Pull Request #1572 · google/adk-python · GitHub
[go: up one dir, main page]

Skip to content

feat: add ConditionalAgent for routing between two sub-agents #1572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions contributing/samples/simple_conditional_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Sample demonstrating `ConditionalAgent` routing."""
from . import agent # noqa: F401
111 changes: 111 additions & 0 deletions contributing/samples/simple_conditional_agent/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""A minimal example demonstrating ConditionalAgent.

This sample shows how to route user requests to different sub-agents based on a
predicate evaluated against the current `InvocationContext`.

• If the user asks to *roll* a die → `roll_agent` is triggered.
• Otherwise, the request is delegated to `prime_agent` for prime-number checks.
"""
from __future__ import annotations

import random
from typing import cast

from google.adk.agents.conditional_agent import ConditionalAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.genai import types

# -----------------------------------------------------------------------------
# Helper tool functions
# -----------------------------------------------------------------------------

def roll_die(sides: int) -> int:
"""Roll a die with a given number of *sides* and return the result."""
return random.randint(1, sides)


def check_prime(nums: list[int]) -> str:
"""Return a formatted string indicating which numbers are prime."""
primes = []
for n in nums:
n = int(n)
if n <= 1:
continue
for i in range(2, int(n ** 0.5) + 1):
if n % i == 0:
break
else:
primes.append(n)
return "No prime numbers found." if not primes else ", ".join(map(str, primes)) + " are prime numbers."

# -----------------------------------------------------------------------------
# Sub-agents definitions
# -----------------------------------------------------------------------------

roll_agent = LlmAgent(
name="roll_agent",
description="Handles rolling dice of different sizes.",
model="gemini-2.0-flash",
instruction=(
"""
You are responsible for rolling dice based on the user's request.\n
When asked to roll a die, call the `roll_die` tool with the number of
sides. Do **not** decide the outcome yourself – always use the tool.
"""
),
tools=[roll_die],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)

prime_agent = LlmAgent(
name="prime_agent",
description="Checks whether provided numbers are prime.",
model="gemini-2.0-flash",
instruction=(
"""
You determine if numbers are prime.\n
Whenever the user asks about prime numbers, call the `check_prime` tool
with a list of integers and return its result. Never attempt to compute
primes manually – always rely on the tool.
"""
),
tools=[check_prime],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting(
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)

# -----------------------------------------------------------------------------
# Predicate used by the ConditionalAgent
# -----------------------------------------------------------------------------

def is_roll_request(ctx: InvocationContext) -> bool:
"""Return True if the last user message seems to be a *roll* request."""
if not ctx.user_content or not ctx.user_content.parts:
return False
text = cast(str, getattr(ctx.user_content.parts[0], "text", "")).lower()
return "roll" in text

# -----------------------------------------------------------------------------
# Root ConditionalAgent
# -----------------------------------------------------------------------------

root_agent = ConditionalAgent(
name="simple_conditional_agent",
description="Routes to roll or prime agent based on user's intent.",
sub_agents=[roll_agent, prime_agent],
condition=is_roll_request,
)
2 changes: 2 additions & 0 deletions src/google/adk/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .llm_agent import LlmAgent
from .loop_agent import LoopAgent
from .parallel_agent import ParallelAgent
from .conditional_agent import ConditionalAgent
from .run_config import RunConfig
from .sequential_agent import SequentialAgent

Expand All @@ -28,5 +29,6 @@
'LlmAgent',
'LoopAgent',
'ParallelAgent',
'ConditionalAgent',
'SequentialAgent',
]
89 changes: 89 additions & 0 deletions src/google/adk/agents/conditional_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Conditional agent implementation.

This agent evaluates a provided predicate (condition) against the current
`InvocationContext` and delegates the call to one of *exactly two* sub-agents
based on the result:

* ``condition(ctx) == True`` -> ``sub_agents[0]`` ("true" agent)
* ``condition(ctx) == False`` -> ``sub_agents[1]`` ("false" agent)

Typical usages include simple routing / decision making, e.g. directing the
conversation to either a *search* agent or *calculation* agent depending on the
user request type.

Example:

```python
router = ConditionalAgent(
name="router",
description="Route to calc or search agent based on context flag",
sub_agents=[calc_agent, search_agent],
condition=lambda ctx: ctx.state.get("needs_calc", False),
)
```

The condition can be either synchronous or asynchronous. If the predicate is
awaitable (i.e. returns an ``Awaitable[bool]``), it will be awaited.
"""
from __future__ import annotations

import inspect
from typing import Awaitable, Callable, AsyncGenerator, Union

from typing_extensions import override

from ..events.event import Event
from .base_agent import BaseAgent
from .invocation_context import InvocationContext

# Type alias for the predicate
Condition = Callable[[InvocationContext], Union[bool, Awaitable[bool]]]


class ConditionalAgent(BaseAgent):
"""A shell agent that chooses between two sub-agents based on a condition."""

# NOTE: The predicate function itself is **not** serialisable; exclude from
# model dump to avoid pydantic complaining when exporting to json.
condition: Condition # type: ignore[assignment]

# ---------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------
async def _evaluate_condition(self, ctx: InvocationContext) -> bool:
"""Evaluates *self.condition* and returns the boolean result."""
result = self.condition(ctx)
if inspect.isawaitable(result):
result = await result # type: ignore[assignment]
return bool(result)

def _validate_sub_agent_count(self) -> None:
if len(self.sub_agents) != 2:
raise ValueError(
"ConditionalAgent requires *exactly* two sub-agents (true / false)."
)

# ---------------------------------------------------------------------
# Core execution implementations
# ---------------------------------------------------------------------
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Delegates to either ``sub_agents[0]`` or ``sub_agents[1]``."""
self._validate_sub_agent_count()
chosen_agent_idx = 0 if await self._evaluate_condition(ctx) else 1
chosen_agent = self.sub_agents[chosen_agent_idx]
async for event in chosen_agent.run_async(ctx):
yield event

@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Live mode implementation mirroring the async version."""
self._validate_sub_agent_count()
chosen_agent_idx = 0 if await self._evaluate_condition(ctx) else 1
chosen_agent = self.sub_agents[chosen_agent_idx]
async for event in chosen_agent.run_live(ctx):
yield event
133 changes: 133 additions & 0 deletions tests/unittests/agents/test_conditional_agent.py
9E7A
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Tests for the ConditionalAgent."""

from __future__ import annotations

import asyncio
from typing import AsyncGenerator

import pytest
from typing_extensions import override

from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.conditional_agent import ConditionalAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types


class _TestingAgent(BaseAgent):
"""A simple testing agent that emits a single event."""

@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
author=self.name,
invocation_id=ctx.invocation_id,
content=types.Content(parts=[types.Part(text=f"Hello, async {self.name}!")]),
)

@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
author=self.name,
invocation_id=ctx.invocation_id,
content=types.Content(parts=[types.Part(text=f"Hello, live {self.name}!")]),
)


async def _create_parent_invocation_context(
test_name: str, agent: BaseAgent
) -> InvocationContext:
session_service = InMemorySessionService()
session = await session_service.create_session(
app_name="test_app", user_id="test_user"
)
return InvocationContext(
invocation_id=f"{test_name}_invocation_id",
agent=agent,
session=session,
session_service=session_service,
)


# ---------------------------------------------------------------------------
# Async tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
@pytest.mark.parametrize("predicate,expected_idx", [(lambda _ctx: True, 0), (lambda _ctx: False, 1)])
async def test_run_async_branching(request: pytest.FixtureRequest, predicate, expected_idx):
agents = [
_TestingAgent(name=f"{request.function.__name__}_agent_true"),
_TestingAgent(name=f"{request.function.__name__}_agent_false"),
]
conditional_agent = ConditionalAgent(
name=f"{request.function.__name__}_conditional_agent",
sub_agents=agents,
condition=predicate,
)
parent_ctx = await _create_parent_invocation_context(request.function.__name__, conditional_agent)
events = [e async for e in conditional_agent.run_async(parent_ctx)]

# Only the chosen agent should produce events
assert len(events) == 1
assert events[0].author == agents[expected_idx].name
assert events[0].content.parts[0].text == f"Hello, async {agents[expected_idx].name}!"


@pytest.mark.asyncio
async def test_run_async_async_predicate(request: pytest.FixtureRequest):
async def async_predicate(_ctx):
await asyncio.sleep(0.01)
return True

agent_true = _TestingAgent(name=f"{request.function.__name__}_agent_true")
agent_false = _TestingAgent(name=f"{request.function.__name__}_agent_false")

conditional_agent = ConditionalAgent(
name=f"{request.function.__name__}_conditional_agent",
sub_agents=[agent_true, agent_false],
condition=async_predicate,
)
parent_ctx = await _create_parent_invocation_context(request.function.__name__, conditional_agent)
events = [e async for e in conditional_agent.run_async(parent_ctx)]

assert len(events) == 1 and events[0].author == agent_true.name


# ---------------------------------------------------------------------------
# Live tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_run_live_branching(request: pytest.FixtureRequest):
agent_true = _TestingAgent(name=f"{request.function.__name__}_agent_true")
agent_false = _TestingAgent(name=f"{request.function.__name__}_agent_false")
conditional_agent = ConditionalAgent(
name=f"{request.function.__name__}_conditional_agent",
sub_agents=[agent_true, agent_false],
condition=lambda _ctx: False,
)
parent_ctx = await _create_parent_invocation_context(request.function.__name__, conditional_agent)
events = [e async for e in conditional_agent.run_live(parent_ctx)]

assert len(events) == 1
assert events[0].author == agent_false.name
assert events[0].content.parts[0].text == f"Hello, live {agent_false.name}!"


# ---------------------------------------------------------------------------
# Validation tests
# ---------------------------------------------------------------------------


def test_invalid_sub_agent_count():
with pytest.raises(ValueError):
ConditionalAgent(name="invalid", sub_agents=[], condition=lambda _ctx: True)
0