fix: propagate grounding and citation metadata in streaming responses · google/adk-python@e6da417 · GitHub
[go: up one dir, main page]

Skip to content

Commit e6da417

Browse files
sasha-gitg authored and copybara-github committed
fix: propagate grounding and citation metadata in streaming responses
Co-authored-by: Sasha Sobran <asobran@google.com> PiperOrigin-RevId: 868324488
1 parent 6ee5126 commit e6da417

File tree

3 files changed

+448
-0
lines changed

3 files changed

+448
-0
lines changed

src/google/adk/utils/streaming_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def __init__(self) -> None:
3636
self._text = ''
3737
self._thought_text = ''
3838
self._usage_metadata = None
39+
self._grounding_metadata: Optional[types.GroundingMetadata] = None
40+
self._citation_metadata: Optional[types.CitationMetadata] = None
3941
self._response = None
4042

4143
# For progressive SSE streaming mode: accumulate parts in order
@@ -251,6 +253,10 @@ async def process_response(
251253
self._response = response
252254
llm_response = LlmResponse.create(response)
253255
self._usage_metadata = llm_response.usage_metadata
256+
if llm_response.grounding_metadata:
257+
self._grounding_metadata = llm_response.grounding_metadata
258+
if llm_response.citation_metadata:
259+
self._citation_metadata = llm_response.citation_metadata
254260

255261
# ========== Progressive SSE Streaming (new feature) ==========
256262
# Save finish_reason for final aggregation
@@ -347,6 +353,8 @@ def close(self) -> Optional[LlmResponse]:
347353

348354
return LlmResponse(
349355
content=types.ModelContent(parts=final_parts),
356+
grounding_metadata=self._grounding_metadata,
357+
citation_metadata=self._citation_metadata,
350358
error_code=None
351359
if finish_reason == types.FinishReason.STOP
352360
else finish_reason,
@@ -374,6 +382,8 @@ def close(self) -> Optional[LlmResponse]:
374382
candidate = self._response.candidates[0]
375383
return LlmResponse(
376384
content=types.ModelContent(parts=parts),
385+
grounding_metadata=self._grounding_metadata,
386+
citation_metadata=self._citation_metadata,
377387
error_code=None
378388
if candidate.finish_reason == types.FinishReason.STOP
379389
else candidate.finish_reason,
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Integration tests for grounding metadata preservation in SSE streaming.

Verifies that grounding_metadata from VertexAiSearchTool reaches the final
non-partial event in both progressive and non-progressive SSE streaming modes.

Prerequisites:
  - GOOGLE_CLOUD_PROJECT env var set to a GCP project with Vertex AI enabled
  - Discovery Engine API enabled (discoveryengine.googleapis.com)
  - Authenticated via `gcloud auth application-default login`

Usage:
  GOOGLE_CLOUD_PROJECT=my-project pytest
  tests/integration/test_vertex_ai_search_grounding_streaming.py -v -s
"""

from __future__ import annotations

import json
import os
import time
import uuid

from google.adk.features._feature_registry import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.genai import types
import pytest

# Target project/location for the live Discovery Engine resources; the
# `project_id` fixture skips the whole module when the project is unset.
_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT", "")
_LOCATION = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")
_COLLECTION = "default_collection"
# Random suffix keeps concurrent test runs from colliding on the same store.
_DATA_STORE_ID = f"adk-grounding-test-{uuid.uuid4().hex[:8]}"
_DATA_STORE_DISPLAY_NAME = "ADK Grounding Integration Test"
_MODEL = "gemini-2.0-flash"

# Small corpus ingested into the data store; queries in the tests below are
# phrased so the search tool should ground its answers in these documents.
_TEST_DOCUMENTS = (
    {
        "id": "doc-adk-overview",
        "title": "ADK Overview",
        "content": (
            "The Agent Development Kit (ADK) is an open-source framework by"
            " Google for building AI agents. ADK supports multi-agent"
            " architectures, tool use, and integrates with Gemini models."
            " ADK was first released in April 2025."
        ),
    },
    {
        "id": "doc-adk-tools",
        "title": "ADK Built-in Tools",
        "content": (
            "ADK provides built-in tools including VertexAiSearchTool for"
            " grounded search, GoogleSearchTool for web search, and"
            " CodeExecutionTool for running code. The VertexAiSearchTool"
            " returns grounding metadata with citations pointing to source"
            " documents."
        ),
    },
)
72+
73+
74+
def _parent_path() -> str:
  """Return the Discovery Engine collection parent resource path."""
  segments = (
      "projects",
      _PROJECT,
      "locations",
      _LOCATION,
      "collections",
      _COLLECTION,
  )
  return "/".join(segments)
76+
77+
78+
def _data_store_path() -> str:
  """Return the full resource path of the test data store."""
  return "/".join((_parent_path(), "dataStores", _DATA_STORE_ID))
80+
81+
82+
@pytest.fixture(scope="module")
def project_id():
  """Provide the configured GCP project id, or skip the module when unset."""
  if _PROJECT:
    return _PROJECT
  pytest.skip("GOOGLE_CLOUD_PROJECT env var not set")
87+
88+
89+
@pytest.fixture(scope="module")
def data_store_resource(project_id) -> str:
  """Create a Vertex AI Search data store with test documents.

  Yields:
    The full data store resource path. The store is deleted on module
    teardown (best effort — deletion failures are logged, not raised).
  """
  # Imported lazily so collecting this module does not require the
  # google-cloud-discoveryengine package when the tests are skipped.
  from google.api_core.exceptions import AlreadyExists
  from google.cloud import discoveryengine_v1beta as discoveryengine

  ds_client = discoveryengine.DataStoreServiceClient()
  doc_client = discoveryengine.DocumentServiceClient()

  # Create data store
  try:
    request = discoveryengine.CreateDataStoreRequest(
        parent=_parent_path(),
        data_store=discoveryengine.DataStore(
            display_name=_DATA_STORE_DISPLAY_NAME,
            industry_vertical=discoveryengine.IndustryVertical.GENERIC,
            solution_types=[discoveryengine.SolutionType.SOLUTION_TYPE_SEARCH],
            # NO_CONTENT: documents are structured JSON, no unstructured blobs.
            content_config=discoveryengine.DataStore.ContentConfig.NO_CONTENT,
        ),
        data_store_id=_DATA_STORE_ID,
    )
    operation = ds_client.create_data_store(request=request)
    print(f"\nCreating data store '{_DATA_STORE_ID}'...")
    # Creation is a long-running operation; block until it completes.
    operation.result(timeout=120)
    print("Data store created.")
  except AlreadyExists:
    # _DATA_STORE_ID is randomized per run, but reuse is safe if it exists.
    print(f"\nData store '{_DATA_STORE_ID}' already exists, reusing.")

  # Ingest test documents
  branch = f"{_data_store_path()}/branches/default_branch"
  for doc_data in _TEST_DOCUMENTS:
    json_data = json.dumps({
        "title": doc_data["title"],
        "description": doc_data["content"],
    })
    doc = discoveryengine.Document(
        id=doc_data["id"],
        json_data=json_data,
    )
    try:
      doc_client.create_document(
          parent=branch,
          document=doc,
          document_id=doc_data["id"],
      )
      print(f" Created document: {doc_data['id']}")
    except AlreadyExists:
      # Document survives from a previous run — overwrite its content.
      doc_client.update_document(
          document=discoveryengine.Document(
              name=f"{branch}/documents/{doc_data['id']}",
              json_data=json_data,
          ),
      )
      print(f" Updated document: {doc_data['id']}")

  # NOTE(review): fixed 5s wait assumes indexing completes quickly — could
  # be flaky if the backend is slow; confirm against observed index latency.
  print("Waiting 5s for indexing...")
  time.sleep(5)

  yield _data_store_path()

  # Cleanup — best-effort, ignore errors from Discovery Engine LRO
  try:
    operation = ds_client.delete_data_store(name=_data_store_path())
    operation.result(timeout=120)
    print(f"\nDeleted data store '{_DATA_STORE_ID}'.")
  except Exception as e:
    print(f"\nFailed to delete data store '{_DATA_STORE_ID}': {e}")
156+
157+
158+
class TestIntegrationVertexAiSearchGrounding:
  """Integration tests hitting real Vertex AI with VertexAiSearchTool."""

  @pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
  @pytest.mark.parametrize(
      "progressive_sse, label",
      [
          (True, "Progressive SSE"),
          (False, "Non-Progressive SSE"),
      ],
  )
  @pytest.mark.asyncio
  async def test_grounding_metadata_with_sse_streaming(
      self, project_id, data_store_resource, progressive_sse, label
  ):
    """Verifies grounding_metadata in SSE streaming modes.

    Runs the same agent query under both progressive and non-progressive
    SSE streaming and asserts that at least one saved (non-partial) event
    carries grounding_metadata.
    """
    from google.adk.agents.llm_agent import LlmAgent
    from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool

    agent = LlmAgent(
        name="test_agent",
        model=_MODEL,
        tools=[VertexAiSearchTool(data_store_id=data_store_resource)],
        instruction="Answer questions using the search tool.",
    )

    # Toggle the streaming mode under test only for the duration of the run.
    with temporary_feature_override(
        FeatureName.PROGRESSIVE_SSE_STREAMING, progressive_sse
    ):
      all_events, saved_events = await self._run_agent_streaming(
          agent, project_id
      )

    self._report_events(label, all_events, saved_events)

    saved_with_grounding = [e for e in saved_events if e["has_grounding"]]
    assert (
        saved_with_grounding
    ), f"No saved (non-partial) events have grounding_metadata with {label}."

  @pytest.mark.parametrize("llm_backend", ["VERTEX"], indirect=True)
  @pytest.mark.asyncio
  async def test_grounding_metadata_without_streaming(
      self, project_id, data_store_resource
  ):
    """Without streaming, grounding_metadata should always be present."""
    from google.adk.agents.llm_agent import LlmAgent
    from google.adk.agents.run_config import RunConfig
    from google.adk.agents.run_config import StreamingMode
    from google.adk.runners import Runner
    from google.adk.sessions.in_memory_session_service import InMemorySessionService
    from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
    from google.adk.utils.context_utils import Aclosing

    agent = LlmAgent(
        name="test_agent",
        model=_MODEL,
        tools=[VertexAiSearchTool(data_store_id=data_store_resource)],
        instruction="Answer questions using the search tool.",
    )

    session_service = InMemorySessionService()
    runner = Runner(
        app_name="test_app",
        agent=agent,
        session_service=session_service,
    )
    session = await session_service.create_session(
        app_name="test_app", user_id="test_user"
    )

    # Baseline: no streaming at all, so the final response arrives whole.
    run_config = RunConfig(streaming_mode=StreamingMode.NONE)
    events = []
    async with Aclosing(
        runner.run_async(
            user_id="test_user",
            session_id=session.id,
            new_message=types.Content(
                role="user",
                parts=[
                    types.Part.from_text(
                        text="What built-in tools does ADK provide?"
                    )
                ],
            ),
            run_config=run_config,
        )
    ) as agen:
      async for event in agen:
        # Record only the fields the assertions and report need.
        events.append({
            "author": event.author,
            "partial": event.partial,
            "has_grounding": event.grounding_metadata is not None,
            "has_content": bool(event.content and event.content.parts),
        })

    print("\n=== No Streaming ===")
    for i, e in enumerate(events):
      print(
          f" Event {i}: author={e['author']}, partial={e['partial']},"
          f" grounding={e['has_grounding']}, content={e['has_content']}"
      )

    # Only events authored by the model agent can carry grounding metadata.
    model_events = [e for e in events if e["author"] == "test_agent"]
    with_grounding = [e for e in model_events if e["has_grounding"]]
    assert (
        with_grounding
    ), "No events have grounding_metadata even without streaming."

  async def _run_agent_streaming(self, agent, project_id):
    """Run `agent` once with SSE streaming and collect per-event summaries.

    Returns:
      A pair (all_events, saved_events) of event-summary dicts, where
      saved_events are the non-partial events (the ones a session keeps).
    """
    from google.adk.agents.run_config import RunConfig
    from google.adk.agents.run_config import StreamingMode
    from google.adk.runners import Runner
    from google.adk.sessions.in_memory_session_service import InMemorySessionService
    from google.adk.utils.context_utils import Aclosing

    session_service = InMemorySessionService()
    runner = Runner(
        app_name="test_app",
        agent=agent,
        session_service=session_service,
    )
    session = await session_service.create_session(
        app_name="test_app", user_id="test_user"
    )

    run_config = RunConfig(streaming_mode=StreamingMode.SSE)
    all_events = []
    async with Aclosing(
        runner.run_async(
            user_id="test_user",
            session_id=session.id,
            new_message=types.Content(
                role="user",
                parts=[
                    types.Part.from_text(
                        text="What is ADK and when was it first released?"
                    )
                ],
            ),
            run_config=run_config,
        )
    ) as agen:
      async for event in agen:
        all_events.append({
            "author": event.author,
            "partial": event.partial,
            "has_grounding": event.grounding_metadata is not None,
            "has_content": bool(event.content and event.content.parts),
        })

    # `partial` may be False or None on saved events — treat both as saved.
    saved_events = [e for e in all_events if e["partial"] is not True]
    return all_events, saved_events

  def _report_events(self, label, all_events, saved_events):
    """Print a human-readable summary of streamed vs. saved events."""
    print(f"\n=== {label} — All Events ===")
    for i, e in enumerate(all_events):
      print(
          f" Event {i}: author={e['author']}, partial={e['partial']},"
          f" grounding={e['has_grounding']},"
          f" content={e['has_content']}"
      )
    print(f"\n=== {label} — Saved (non-partial) Events ===")
    for i, e in enumerate(saved_events):
      print(
          f" Event {i}: author={e['author']}, partial={e['partial']},"
          f" grounding={e['has_grounding']},"
          f" content={e['has_content']}"
      )
    # Highlight metadata that streamed by in partial events but was never
    # persisted — the exact symptom the fixed propagation code addresses.
    partial_with_grounding = [
        e for e in all_events if e["partial"] is True and e["has_grounding"]
    ]
    if partial_with_grounding:
      print(
          f"\n NOTE: {len(partial_with_grounding)} partial event(s)"
          " had grounding_metadata but were NOT saved to session."
      )

0 commit comments

Comments
 (0)
0