8000 feat: support realtime input config · google/adk-python@d22920b · GitHub
[go: up one dir, main page]

Skip to content

Commit d22920b

Browse files
ammmrcopybara-github
authored andcommitted
feat: support realtime input config
Merge #981 issue: #982 This pull request introduces a new configuration option, `realtime_input_config`, to the `RunConfig` class. **Reason for this change:** Currently, there is no direct way to configure real-time audio input behaviors, such as Voice Activity Detection (VAD), for live agents through the `RunConfig`. The Gemini API documentation (specifically [Configure automatic VAD](https://ai.google.dev/gemini-api/docs/live#configure-automatic-vad)) outlines parameters for VAD that users may want to customize. This change enables users to pass these real-time input configurations, providing more granular control over the audio input for live agents. **Changes made:** - Added a new optional field `realtime_input_config: Optional[types.RealtimeInputConfig]` to the `RunConfig` class. - The docstring for `realtime_input_config` has been added to explain its purpose. **Example Usage (Conceptual):** While the specific structure of `types.RealtimeInputConfig` would define the exact parameters, a user might configure it like this: ```python # (Assuming types.RealtimeInputConfig and types.VadConfig are defined elsewhere) # import your_project.types as types run_config = RunConfig( # ... other configurations ... realtime_input_config=types.RealtimeInputConfig( automatic_activity_detection =types.AutomaticActivityDetection( # VAD specific parameters like sensitivity, endpoint_duration_millis etc. # based on https://ai.google.dev/gemini-api/docs/live#configure-automatic-vad ) # Potentially other real-time input settings could be added here in the future ) ) COPYBARA_INTEGRATE_REVIEW=#981 from ammmr:patch-add-realtime-input-config b2e17fb PiperOrigin-RevId: 770797640
1 parent 2ff9b1f commit d22920b

File tree

6 files changed

+330
-3
lines changed

6 files changed

+330
-3
lines changed

src/google/adk/agents/run_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
from enum import Enum
1618
import logging
1719
import sys
@@ -68,6 +70,9 @@ class RunConfig(BaseModel):
6870
input_audio_transcription: Optional[types.AudioTranscriptionConfig] = None
6971
"""Input transcription for live agents with audio input from user."""
7072

73+
realtime_input_config: Optional[types.RealtimeInputConfig] = None
74+
"""Realtime input config for live agents with audio input from user."""
75+
7176
max_llm_calls: int = 500
7277
"""
7378
A limit on the total number of llm calls for a given run.

src/google/adk/flows/llm_flows/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ async def run_async(
6565
llm_request.live_connect_config.input_audio_transcription = (
6666
invocation_context.run_config.input_audio_transcription
6767
)
68+
llm_request.live_connect_config.realtime_input_config = (
69+
invocation_context.run_config.realtime_input_config
70+
)
6871

6972
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
7073

src/google/adk/models/gemini_llm_connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
1517
import logging
1618
from typing import AsyncGenerator
1719

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
from google.adk.agents import Agent
18+
from google.adk.agents.live_request_queue import LiveRequest
19+
from google.adk.agents.live_request_queue import LiveRequestQueue
20+
from google.adk.agents.run_config import RunConfig
21+
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
22+
from google.adk.models.llm_request import LlmRequest
23+
from google.genai import types
24+
import pytest
25+
26+
from ... import testing_utils
27+
28+
29+
class TestBaseLlmFlow(BaseLlmFlow):
30+
"""Test implementation of BaseLlmFlow for testing purposes."""
31+
32+
pass
33+
34+
35+
@pytest.fixture
36+
def test_blob():
37+
"""Test blob for audio data."""
38+
return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm')
39+
40+
41+
@pytest.fixture
42+
def mock_llm_connection():
43+
"""Mock LLM connection for testing."""
44+
connection = mock.AsyncMock()
45+
connection.send_realtime = mock.AsyncMock()
46+
return connection
47+
48+
49+
@pytest.mark.asyncio
50+
async def test_send_to_model_with_disabled_vad(test_blob, mock_llm_connection):
51+
"""Test _send_to_model with automatic_activity_detection.disabled=True."""
52+
# Create LlmRequest with disabled VAD
53+
realtime_input_config = types.RealtimeInputConfig(
54+
automatic_activity_detection=types.AutomaticActivityDetection(
55+
disabled=True
56+
)
57+
)
58+
59+
# Create invocation context with live request queue
60+
agent = Agent(name='test_agent', model='mock')
61+
invocation_context = await testing_utils.create_invocation_context(
62+
agent=agent,
63+
user_content='',
64+
run_config=RunConfig(realtime_input_config=realtime_input_config),
65+
)
66+
invocation_context.live_request_queue = LiveRequestQueue()
67+
68+
# Create flow and start _send_to_model task
69+
flow = TestBaseLlmFlow()
70+
71+
# Send a blob to the queue
72+
live_request = LiveRequest(blob=test_blob)
73+
invocation_context.live_request_queue.send(live_request)
74+
invocation_context.live_request_queue.close()
75+
76+
# Run _send_to_model
77+
await flow._send_to_model(mock_llm_connection, invocation_context)
78+
79+
mock_llm_connection.send_realtime.assert_called_once_with(test_blob)
80+
81+
82+
@pytest.mark.asyncio
83+
async def test_send_to_model_with_enabled_vad(test_blob, mock_llm_connection):
84+
"""Test _send_to_model with automatic_activity_detection.disabled=False.
85+
86+
Custom VAD activity signal is not supported so we should still disable it.
87+
"""
88+
# Create LlmRequest with enabled VAD
89+
realtime_input_config = types.RealtimeInputConfig(
90+
automatic_activity_detection=types.AutomaticActivityDetection(
91+
disabled=False
92+
)
93+
)
94+
95+
# Create invocation context with live request queue
96+
agent = Agent(name='test_agent', model='mock')
97+
invocation_context = await testing_utils.create_invocation_context(
98+
agent=agent, user_content=''
99+
)
100+
invocation_context.live_request_queue = LiveRequestQueue()
101+
102+
# Create flow and start _send_to_model task
103+
flow = TestBaseLlmFlow()
104+
105+
# Send a blob to the queue
106+
live_request = LiveRequest(blob=test_blob)
107+
invocation_context.live_request_queue.send(live_request)
108+
invocation_context.live_request_queue.close()
109+
110+
# Run _send_to_model
111+
await flow._send_to_model(mock_llm_connection, invocation_context)
112+
113+
mock_llm_connection.send_realtime.assert_called_once_with(test_blob)
114+
115+
116+
@pytest.mark.asyncio
117+
async def test_send_to_model_without_realtime_config(
118+
test_blob, mock_llm_connection
119+
):
120+
"""Test _send_to_model without realtime_input_config (default behavior)."""
121+
# Create invocation context with live request queue
122+
agent = Agent(name='test_agent', model='mock')
123+
invocation_context = await testing_utils.create_invocation_context(
124+
agent=agent, user_content=''
125+
)
126+
invocation_context.live_request_queue = LiveRequestQueue()
127+
128+
# Create flow and start _send_to_model task
129+
flow = TestBaseLlmFlow()
130+
131+
# Send a blob to the queue
132+
live_request = LiveRequest(blob=test_blob)
133+
invocation_context.live_request_queue.send(live_request)
134+
invocation_context.live_request_queue.close()
135+
136+
# Run _send_to_model
137+
await flow._send_to_model(mock_llm_connection, invocation_context)
138+
139+
mock_llm_connection.send_realtime.assert_called_once_with(test_blob)
140+
141+
142+
@pytest.mark.asyncio
143+
async def test_send_to_model_with_none_automatic_activity_detection(
144+
test_blob, mock_llm_connection
145+
):
146+
"""Test _send_to_model with automatic_activity_detection=None."""
147+
# Create LlmRequest with None automatic_activity_detection
148+
realtime_input_config = types.RealtimeInputConfig(
149+
automatic_activity_detection=None
150+
)
151+
152+
# Create invocation context with live request queue
153+
agent = Agent(name='test_agent', model='mock')
154+
invocation_context = await testing_utils.create_invocation_context(
155+
agent=agent,
156+
user_content='',
157+
run_config=RunConfig(realtime_input_config=realtime_input_config),
158+
)
159+
invocation_context.live_request_queue = LiveRequestQueue()
160+
161+
# Create flow and start _send_to_model task
162+
flow = TestBaseLlmFlow()
163+
164+
# Send a blob to the queue
165+
live_request = LiveRequest(blob=test_blob)
166+
invocation_context.live_request_queue.send(live_request)
167+
invocation_context.live_request_queue.close()
168+
169+
# Run _send_to_model
170+
await flow._send_to_model(mock_llm_connection, invocation_context)
171+
172+
mock_llm_connection.send_realtime.assert_called_once_with(test_blob)
173+
174+
175+
@pytest.mark.asyncio
176+
async def test_send_to_model_with_text_content(mock_llm_connection):
177+
"""Test _send_to_model with text content (not blob)."""
178+
# Create invocation context with live request queue
179+
agent = Agent(name='test_agent', model='mock')
180+
invocation_context = await testing_utils.create_invocation_context(
181+
agent=agent, user_content=''
182+
)
183+
invocation_context.live_request_queue = LiveRequestQueue()
184+
185+
# Create flow and start _send_to_model task
186+
flow = TestBaseLlmFlow()
187+
188+
# Send text content to the queue
189+
content = types.Content(
190+
role='user', parts=[types.Part.from_text(text='Hello')]
191+
)
192+
live_request = LiveRequest(content=content)
193+
invocation_context.live_request_queue.send(live_request)
194+
invocation_context.live_request_queue.close()
195+
196+
# Run _send_to_model
197+
await flow._send_to_model(mock_llm_connection, invocation_context)
198+
199+
# Verify send_content was called instead of send_realtime
200+
mock_llm_connection.send_content.assert_called_once_with(content)
201+
mock_llm_connection.send_realtime.assert_not_called()
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
18+
from google.genai import types
19+
import pytest
20+
21+
22+
@pytest.fixture
23+
def mock_gemini_session():
24+
"""Mock Gemini session for testing."""
25+
return mock.AsyncMock()
26+
27+
28+
@pytest.fixture
29+
def gemini_connection(mock_gemini_session):
30+
"""GeminiLlmConnection instance with mocked session."""
31+
return GeminiLlmConnection(mock_gemini_session)
32+
33+
34+
@pytest.fixture
35+
def test_blob():
36+
"""Test blob for audio data."""
37+
return types.Blob(data=b'\x00\xFF\x00\xFF', mime_type='audio/pcm')
38+
39+
40+
@pytest.mark.asyncio
41+
async def test_send_realtime_default_behavior(
42+
gemini_connection, mock_gemini_session, test_blob
43+
):
44+
"""Test send_realtime with default automatic_activity_detection value (True)."""
45+
await gemini_connection.send_realtime(test_blob)
46+
47+
# Should call send once
48+
mock_gemini_session.send.assert_called_once_with(input=test_blob.model_dump())
49+
50+
51+
@pytest.mark.asyncio
52+
async def test_send_history(gemini_connection, mock_gemini_session):
53+
"""Test send_history method."""
54+
history = [
55+
types.Content(role='user', parts=[types.Part.from_text(text='Hello')]),
56+
types.Content(
57+
role='model', parts=[types.Part.from_text(text='Hi there!')]
58+
),
59+
]
60+
61+
await gemini_connection.send_history(history)
62+
63+
mock_gemini_session.send.assert_called_once()
64+
call_args = mock_gemini_session.send.call_args[1]
65+
assert 'input' in call_args
66+
assert call_args['input'].turns == history
67+
assert call_args['input'].turn_complete is False # Last message is from model
68+
69+
70+
@pytest.mark.asyncio
71+
async def test_send_content_text(gemini_connection, mock_gemini_session):
72+
"""Test send_content with text content."""
73+
content = types.Content(
74+
role='user', parts=[types.Part.from_text(text='Hello')]
75+
)
76+
77+
await gemini_connection.send_content(content)
78+
79+
mock_gemini_session.send.assert_called_once()
80+
call_args = mock_gemini_session.send.call_args[1]
81+
assert 'input' in call_args
82+
assert call_args['input'].turns == [content]
83+
assert call_args['input'].turn_complete is True
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_send_content_function_response(
88+
gemini_connection, mock_gemini_session
89+
):
90+
"""Test send_content with function response."""
91+
function_response = types.FunctionResponse(
92+
name='test_function', response={'result': 'success'}
93+
)
94+
content = types.Content(
95+
role='user', parts=[types.Part(function_response=function_response)]
96+
)
97+
98+
await gemini_connection.send_content(content)
99+
100+
mock_gemini_session.send.assert_called_once()
101+
call_args = mock_gemini_session.send.call_args[1]
102+
assert 'input' in call_args
103+
assert call_args['input'].function_responses == [function_response]
104+
105+
106+
@pytest.mark.asyncio
107+
async def test_close(gemini_connection, mock_gemini_session):
108+
"""Test close method."""
109+
await gemini_connection.close()
110+
111+
mock_gemini_session.close.assert_called_once()

tests/unittests/testing_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def __init__(self, parts: list[types.Part]):
5656
super().__init__(role='model', parts=parts)
5757

5858

59-
async def create_invocation_context(agent: Agent, user_content: str = ''):
59+
async def create_invocation_context(
60+
agent: Agent, user_content: str = '', run_config: RunConfig = None
61+
):
6062
invocation_id = 'test_id'
6163
artifact_service = InMemoryArtifactService()
6264
session_service = InMemorySessionService()
@@ -73,7 +75,7 @@ async def create_invocation_context(agent: Agent, user_content: str = ''):
7375
user_content=types.Content(
7476
role='user', parts=[types.Part.from_text(text=user_content)]
7577
),
76-
run_config=RunConfig(),
78+
run_config=run_config or RunConfig(),
7779
)
7880
if user_content:
7981
append_user_content(
@@ -205,13 +207,16 @@ async def run_async(self, new_message: types.ContentUnion) -> list[Event]:
205207
events.append(event)
206208
return events
207209

208-
def run_live(self, live_request_queue: LiveRequestQueue) -> list[Event]:
210+
def run_live(
211+
self, live_request_queue: LiveRequestQueue, run_config: RunConfig = None
212+
) -> list[Event]:
209213
collected_responses = []
210214

211215
async def consume_responses(session: Session):
212216
run_res = self.runner.run_live(
213217
session=session,
214218
live_request_queue=live_request_queue,
219+
run_config=run_config or RunConfig(),
215220
)
216221

217222
async for response in run_res:

0 commit comments

Comments
 (0)
0