diff --git a/contributing/samples/simple_conditional_agent/__init__.py b/contributing/samples/simple_conditional_agent/__init__.py new file mode 100644 index 000000000..ee9ac11be --- /dev/null +++ b/contributing/samples/simple_conditional_agent/__init__.py @@ -0,0 +1,2 @@ +"""Sample demonstrating `ConditionalAgent` routing.""" +from . import agent # noqa: F401 diff --git a/contributing/samples/simple_conditional_agent/agent.py b/contributing/samples/simple_conditional_agent/agent.py new file mode 100644 index 000000000..15a3a40a8 --- /dev/null +++ b/contributing/samples/simple_conditional_agent/agent.py @@ -0,0 +1,111 @@ +"""A minimal example demonstrating ConditionalAgent. + +This sample shows how to route user requests to different sub-agents based on a +predicate evaluated against the current `InvocationContext`. + +• If the user asks to *roll* a die → `roll_agent` is triggered. +• Otherwise, the request is delegated to `prime_agent` for prime-number checks. +""" +from __future__ import annotations + +import random +from typing import cast + +from google.adk.agents.conditional_agent import ConditionalAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.genai import types + +# ----------------------------------------------------------------------------- +# Helper tool functions +# ----------------------------------------------------------------------------- + +def roll_die(sides: int) -> int: + """Roll a die with a given number of *sides* and return the result.""" + return random.randint(1, sides) + + +def check_prime(nums: list[int]) -> str: + """Return a formatted string indicating which numbers are prime.""" + primes = [] + for n in nums: + n = int(n) + if n <= 1: + continue + for i in range(2, int(n ** 0.5) + 1): + if n % i == 0: + break + else: + primes.append(n) + return "No prime numbers found." if not primes else ", ".join(map(str, primes)) + " are prime numbers." + +# ----------------------------------------------------------------------------- +# Sub-agents definitions +# ----------------------------------------------------------------------------- + +roll_agent = LlmAgent( + name="roll_agent", + description="Handles rolling dice of different sizes.", + model="gemini-2.0-flash", + instruction=( + """ + You are responsible for rolling dice based on the user's request.\n + When asked to roll a die, call the `roll_die` tool with the number of + sides. Do **not** decide the outcome yourself – always use the tool. + """ + ), + tools=[roll_die], + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) + +prime_agent = LlmAgent( + name="prime_agent", + description="Checks whether provided numbers are prime.", + model="gemini-2.0-flash", + instruction=( + """ + You determine if numbers are prime.\n + Whenever the user asks about prime numbers, call the `check_prime` tool + with a list of integers and return its result. Never attempt to compute + primes manually – always rely on the tool. + """ + ), + tools=[check_prime], + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) + +# ----------------------------------------------------------------------------- +# Predicate used by the ConditionalAgent +# ----------------------------------------------------------------------------- + +def is_roll_request(ctx: InvocationContext) -> bool: + """Return True if the last user message seems to be a *roll* request.""" + if not ctx.user_content or not ctx.user_content.parts: + return False + text = cast(str, getattr(ctx.user_content.parts[0], "text", "")).lower() + return "roll" in text + +# ----------------------------------------------------------------------------- +# Root ConditionalAgent +# ----------------------------------------------------------------------------- + +root_agent = ConditionalAgent( + name="simple_conditional_agent", + description="Routes to roll or prime agent based on user's intent.", + sub_agents=[roll_agent, prime_agent], + condition=is_roll_request, +) diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index e1f773c47..eb5f0151a 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -19,6 +19,7 @@ from .llm_agent import LlmAgent from .loop_agent import LoopAgent from .parallel_agent import ParallelAgent +from .conditional_agent import ConditionalAgent from .run_config import RunConfig from .sequential_agent import SequentialAgent @@ -28,5 +29,6 @@ 'LlmAgent', 'LoopAgent', 'ParallelAgent', + 'ConditionalAgent', 'SequentialAgent', ] diff --git a/src/google/adk/agents/conditional_agent.py b/src/google/adk/agents/conditional_agent.py new file mode 100644 index 000000000..ed7090247 --- /dev/null +++ b/src/google/adk/agents/conditional_agent.py @@ -0,0 +1,89 @@ +"""Conditional agent implementation. + +This agent evaluates a provided predicate (condition) against the current +`InvocationContext` and delegates the call to one of *exactly two* sub-agents +based on the result: + + * ``condition(ctx) == True`` -> ``sub_agents[0]`` ("true" agent) + * ``condition(ctx) == False`` -> ``sub_agents[1]`` ("false" agent) + +Typical usages include simple routing / decision making, e.g. directing the +conversation to either a *search* agent or *calculation* agent depending on the +user request type. + +Example: + +```python +router = ConditionalAgent( + name="router", + description="Route to calc or search agent based on context flag", + sub_agents=[calc_agent, search_agent], + condition=lambda ctx: ctx.state.get("needs_calc", False), +) +``` + +The condition can be either synchronous or asynchronous. If the predicate is +awaitable (i.e. returns an ``Awaitable[bool]``), it will be awaited. +""" +from __future__ import annotations + +import inspect +from typing import Awaitable, Callable, AsyncGenerator, Union + +from typing_extensions import override + +from ..events.event import Event +from .base_agent import BaseAgent +from .invocation_context import InvocationContext + +# Type alias for the predicate +Condition = Callable[[InvocationContext], Union[bool, Awaitable[bool]]] + + +class ConditionalAgent(BaseAgent): + """A shell agent that chooses between two sub-agents based on a condition.""" + + # NOTE: The predicate function itself is **not** serialisable; exclude from + # model dump to avoid pydantic complaining when exporting to json. + condition: Condition # type: ignore[assignment] + + # --------------------------------------------------------------------- + # Internal helpers + # --------------------------------------------------------------------- + async def _evaluate_condition(self, ctx: InvocationContext) -> bool: + """Evaluates *self.condition* and returns the boolean result.""" + result = self.condition(ctx) + if inspect.isawaitable(result): + result = await result # type: ignore[assignment] + return bool(result) + + def _validate_sub_agent_count(self) -> None: + if len(self.sub_agents) != 2: + raise ValueError( + "ConditionalAgent requires *exactly* two sub-agents (true / false)." + ) + + # --------------------------------------------------------------------- + # Core execution implementations + # --------------------------------------------------------------------- + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Delegates to either ``sub_agents[0]`` or ``sub_agents[1]``.""" + self._validate_sub_agent_count() + chosen_agent_idx = 0 if await self._evaluate_condition(ctx) else 1 + chosen_agent = self.sub_agents[chosen_agent_idx] + async for event in chosen_agent.run_async(ctx): + yield event + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Live mode implementation mirroring the async version.""" + self._validate_sub_agent_count() + chosen_agent_idx = 0 if await self._evaluate_condition(ctx) else 1 + chosen_agent = self.sub_agents[chosen_agent_idx] + async for event in chosen_agent.run_live(ctx): + yield event diff --git a/tests/unittests/agents/test_conditional_agent.py b/tests/unittests/agents/test_conditional_agent.py new file mode 100644 index 000000000..59f166515 --- /dev/null +++ b/tests/unittests/agents/test_conditional_agent.py @@ -0,0 +1,133 @@ +"""Tests for the ConditionalAgent.""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator + +import pytest +from typing_extensions import override + +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.conditional_agent import ConditionalAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events import Event +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + + +class _TestingAgent(BaseAgent): + """A simple testing agent that emits a single event.""" + + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text=f"Hello, async {self.name}!")]), + ) + + @override + async def _run_live_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + yield Event( + author=self.name, + invocation_id=ctx.invocation_id, + content=types.Content(parts=[types.Part(text=f"Hello, live {self.name}!")]), + ) + + +async def _create_parent_invocation_context( + test_name: str, agent: BaseAgent +) -> InvocationContext: + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name="test_app", user_id="test_user" + ) + return InvocationContext( + invocation_id=f"{test_name}_invocation_id", + agent=agent, + session=session, + session_service=session_service, + ) + + +# --------------------------------------------------------------------------- +# Async tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.parametrize("predicate,expected_idx", [(lambda _ctx: True, 0), (lambda _ctx: False, 1)]) +async def test_run_async_branching(request: pytest.FixtureRequest, predicate, expected_idx): + agents = [ + _TestingAgent(name=f"{request.function.__name__}_agent_true"), + _TestingAgent(name=f"{request.function.__name__}_agent_false"), + ] + conditional_agent = ConditionalAgent( + name=f"{request.function.__name__}_conditional_agent", + sub_agents=agents, + condition=predicate, + ) + parent_ctx = await _create_parent_invocation_context(request.function.__name__, conditional_agent) + events = [e async for e in conditional_agent.run_async(parent_ctx)] + + # Only the chosen agent should produce events + assert len(events) == 1 + assert events[0].author == agents[expected_idx].name + assert events[0].content.parts[0].text == f"Hello, async {agents[expected_idx].name}!" + + +@pytest.mark.asyncio +async def test_run_async_async_predicate(request: pytest.FixtureRequest): + async def async_predicate(_ctx): + await asyncio.sleep(0.01) + return True + + agent_true = _TestingAgent(name=f"{request.function.__name__}_agent_true") + agent_false = _TestingAgent(name=f"{request.function.__name__}_agent_false") + + conditional_agent = ConditionalAgent( + name=f"{request.function.__name__}_conditional_agent", + sub_agents=[agent_true, agent_false], + condition=async_predicate, + ) + parent_ctx = await _create_parent_invocation_context(request.function.__name__, conditional_agent) + events = [e async for e in conditional_agent.run_async(parent_ctx)] + + assert len(events) == 1 and events[0].author == agent_true.name + + +# --------------------------------------------------------------------------- +# Live tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_run_live_branching(request: pytest.FixtureRequest): + agent_true = _TestingAgent(name=f"{request.function.__name__}_agent_true") + agent_false = _TestingAgent(name=f"{request.function.__name__}_agent_false") + conditional_agent = ConditionalAgent( + name=f"{request.function.__name__}_conditional_agent", + sub_agents=[agent_true, agent_false], + condition=lambda _ctx: False, + ) + parent_ctx = await _create_parent_invocation_context(request.function.__name__, conditional_agent) + events = [e async for e in conditional_agent.run_live(parent_ctx)] + + assert len(events) == 1 + assert events[0].author == agent_false.name + assert events[0].content.parts[0].text == f"Hello, live {agent_false.name}!" + + +# --------------------------------------------------------------------------- +# Validation tests +# --------------------------------------------------------------------------- + + +def test_invalid_sub_agent_count(): + with pytest.raises(ValueError): + ConditionalAgent(name="invalid", sub_agents=[], condition=lambda _ctx: True)