feat: Support dynamic config for VertexAiSearchTool · google/adk-python@585ebfd · GitHub
[go: up one dir, main page]

Skip to content

Commit 585ebfd

Browse files
xuanyang15 and copybara-github
authored and committed
feat: Support dynamic config for VertexAiSearchTool
Close: #4067 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 862965769
1 parent c0c98d9 commit 585ebfd

File tree

2 files changed

+165
-14
lines changed

src/google/adk/tools/vertex_ai_search_tool.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.genai import types
2222
from typing_extensions import override
2323

24+
from ..agents.readonly_context import ReadonlyContext
2425
from ..utils.model_name_utils import is_gemini_1_model
2526
from ..utils.model_name_utils import is_gemini_model
2627
from .base_tool import BaseTool
@@ -38,6 +39,25 @@ class VertexAiSearchTool(BaseTool):
3839
Attributes:
3940
data_store_id: The Vertex AI search data store resource ID.
4041
search_engine_id: The Vertex AI search engine resource ID.
42+
43+
To dynamically customize the search configuration at runtime (e.g., set
44+
filter based on user context), subclass this tool and override the
45+
`_build_vertex_ai_search_config` method.
46+
47+
Example:
48+
```python
49+
class DynamicFilterSearchTool(VertexAiSearchTool):
50+
def _build_vertex_ai_search_config(
51+
self, ctx: ReadonlyContext
52+
) -> types.VertexAISearch:
53+
user_id = ctx.state.get('user_id')
54+
return types.VertexAISearch(
55+
datastore=self.data_store_id,
56+
engine=self.search_engine_id,
57+
filter=f"user_id = '{user_id}'",
58+
max_results=self.max_results,
59+
)
60+
```
4161
"""
4262

4363
def __init__(
@@ -90,6 +110,30 @@ def __init__(
90110
self.max_results = max_results
91111
self.bypass_multi_tools_limit = bypass_multi_tools_limit
92112

113+
def _build_vertex_ai_search_config(
114+
self, readonly_context: ReadonlyContext
115+
) -> types.VertexAISearch:
116+
"""Builds the VertexAISearch configuration.
117+
118+
Override this method in a subclass to dynamically customize the search
119+
configuration based on the context (e.g., set filter based on session
120+
state).
121+
122+
Args:
123+
readonly_context: The readonly context with access to state and session
124+
info.
125+
126+
Returns:
127+
The VertexAISearch configuration to use for this request.
128+
"""
129+
return types.VertexAISearch(
130+
datastore=self.data_store_id,
131+
data_store_specs=self.data_store_specs,
132+
engine=self.search_engine_id,
133+
filter=self.filter,
134+
max_results=self.max_results,
135+
)
136+
93137
@override
94138
async def process_llm_request(
95139
self,
@@ -106,14 +150,20 @@ async def process_llm_request(
106150
llm_request.config = llm_request.config or types.GenerateContentConfig()
107151
llm_request.config.tools = llm_request.config.tools or []
108152

153+
# Build the search config (can be overridden by subclasses)
154+
vertex_ai_search_config = self._build_vertex_ai_search_config(
155+
tool_context
156+
)
157+
109158
# Format data_store_specs concisely for logging
110-
if self.data_store_specs:
159+
if vertex_ai_search_config.data_store_specs:
111160
spec_ids = [
112161
spec.data_store.split('/')[-1] if spec.data_store else 'unnamed'
113-
for spec in self.data_store_specs
162+
for spec in vertex_ai_search_config.data_store_specs
114163
]
115164
specs_info = (
116-
f'{len(self.data_store_specs)} spec(s): [{", ".join(spec_ids)}]'
165+
f'{len(vertex_ai_search_config.data_store_specs)} spec(s):'
166+
f' [{", ".join(spec_ids)}]'
117167
)
118168
else:
119169
specs_info = None
@@ -122,23 +172,17 @@ async def process_llm_request(
122172
'Adding Vertex AI Search tool config to LLM request: '
123173
'datastore=%s, engine=%s, filter=%s, max_results=%s, '
124174
'data_store_specs=%s',
125-
self.data_store_id,
126-
self.search_engine_id,
127-
self.filter,
128-
self.max_results,
175+
vertex_ai_search_config.datastore,
176+
vertex_ai_search_config.engine,
177+
vertex_ai_search_config.filter,
178+
vertex_ai_search_config.max_results,
129179
specs_info,
130180
)
131181

132182
llm_request.config.tools.append(
133183
types.Tool(
134184
retrieval=types.Retrieval(
135-
vertex_ai_search=types.VertexAISearch(
136-
datastore=self.data_store_id,
137-
data_store_specs=self.data_store_specs,
138-
engine=self.search_engine_id,
139-
filter=self.filter,
140-
max_results=self.max_results,
141-
)
185+
vertex_ai_search=vertex_ai_search_config
142186
)
143187
)
144188
)

tests/unittests/tools/test_vertex_ai_search_tool.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,110 @@ async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds(
449449
assert 'filter=None' in log_message
450450
assert 'max_results=None' in log_message
451451
assert 'data_store_specs=None' in log_message
452+
453+
@pytest.mark.asyncio
454+
async def test_subclass_with_dynamic_filter(self):
455+
"""Test subclassing to provide dynamic filter based on context."""
456+
457+
class DynamicFilterSearchTool(VertexAiSearchTool):
458+
"""Custom search tool with dynamic filter."""
459+
460+
def _build_vertex_ai_search_config(self, ctx):
461+
user_id = ctx.state.get('user_id', 'default_user')
462+
return types.VertexAISearch(
463+
datastore=self.data_store_id,
464+
engine=self.search_engine_id,
465+
filter=f"user_id = '{user_id}'",
466+
max_results=self.max_results,
467+
)
468+
469+
tool = DynamicFilterSearchTool(data_store_id='test_data_store')
470+
tool_context = await _create_tool_context()
471+
tool_context.state['user_id'] = 'test_user_123'
472+
473+
llm_request = LlmRequest(
474+
model='gemini-2.5-pro', config=types.GenerateContentConfig()
475+
)
476+
477+
await tool.process_llm_request(
478+
tool_context=tool_context, llm_request=llm_request
479+
)
480+
481+
assert llm_request.config.tools is not None
482+
assert len(llm_request.config.tools) == 1
483+
retrieval_tool = llm_request.config.tools[0]
484+
assert retrieval_tool.retrieval is not None
485+
assert retrieval_tool.retrieval.vertex_ai_search is not None
486+
# Verify the filter was dynamically set
487+
assert (
488+
retrieval_tool.retrieval.vertex_ai_search.filter
489+
== "user_id = 'test_user_123'"
490+
)
491+
492+
@pytest.mark.asyncio
493+
async def test_subclass_with_dynamic_max_results(self):
494+
"""Test subclassing to provide dynamic max_results based on context."""
495+
496+
class DynamicMaxResultsSearchTool(VertexAiSearchTool):
497+
"""Custom search tool with dynamic max_results."""
498+
499+
def _build_vertex_ai_search_config(self, ctx):
500+
# Use a larger max_results for premium users
501+
is_premium = ctx.state.get('is_premium', False)
502+
dynamic_max_results = 20 if is_premium else 5
503+
return types.VertexAISearch(
504+
datastore=self.data_store_id,
505+
engine=self.search_engine_id,
506+
filter=self.filter,
507+
max_results=dynamic_max_results,
508+
)
509+
510+
tool = DynamicMaxResultsSearchTool(
511+
data_store_id='test_data_store', max_results=10
512+
)
513+
tool_context = await _create_tool_context()
514+
tool_context.state['is_premium'] = True
515+
516+
llm_request = LlmRequest(
517+
model='gemini-2.5-pro', config=types.GenerateContentConfig()
518+
)
519+
520+
await tool.process_llm_request(
521+
tool_context=tool_context, llm_request=llm_request
522+
)
523+
524+
retrieval_tool = llm_request.config.tools[0]
525+
# Verify max_results was dynamically set to premium value
526+
assert retrieval_tool.retrieval.vertex_ai_search.max_results == 20
527+
528+
@pytest.mark.asyncio
529+
async def test_subclass_receives_readonly_context(self):
530+
"""Test that subclass receives the context correctly."""
531+
received_contexts = []
532+
533+
class ContextCapturingSearchTool(VertexAiSearchTool):
534+
"""Custom search tool that captures the context."""
535+
536+
def _build_vertex_ai_search_config(self, ctx):
537+
received_contexts.append(ctx)
538+
return types.VertexAISearch(
539+
datastore=self.data_store_id,
540+
engine=self.search_engine_id,
541+
filter=self.filter,
542+
max_results=self.max_results,
543+
)
544+
545+
tool = ContextCapturingSearchTool(data_store_id='test_data_store')
546+
tool_context = await _create_tool_context()
547+
548+
llm_request = LlmRequest(
549+
model='gemini-2.5-pro', config=types.GenerateContentConfig()
550+
)
551+
552+
await tool.process_llm_request(
553+
tool_context=tool_context, llm_request=llm_request
554+
)
555+
556+
# Verify the context was passed to _build_vertex_ai_search_config
557+
assert len(received_contexts) == 1
558+
assert received_contexts[0] is tool_context

0 commit comments

Comments
 (0)