@@ -449,3 +449,110 @@ async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds(
449449 assert 'filter=None' in log_message
450450 assert 'max_results=None' in log_message
451451 assert 'data_store_specs=None' in log_message
452+
453+ @pytest .mark .asyncio
454+ async def test_subclass_with_dynamic_filter (self ):
455+ """Test subclassing to provide dynamic filter based on context."""
456+
457+ class DynamicFilterSearchTool (VertexAiSearchTool ):
458+ """Custom search tool with dynamic filter."""
459+
460+ def _build_vertex_ai_search_config (self , ctx ):
461+ user_id = ctx .state .get ('user_id' , 'default_user' )
462+ return types .VertexAISearch (
463+ datastore = self .data_store_id ,
464+ engine = self .search_engine_id ,
465+ filter = f"user_id = '{ user_id } '" ,
466+ max_results = self .max_results ,
467+ )
468+
469+ tool = DynamicFilterSearchTool (data_store_id = 'test_data_store' )
470+ tool_context = await _create_tool_context ()
471+ tool_context .state ['user_id' ] = 'test_user_123'
472+
473+ llm_request = LlmRequest (
474+ model = 'gemini-2.5-pro' , config = types .GenerateContentConfig ()
475+ )
476+
477+ await tool .process_llm_request (
478+ tool_context = tool_context , llm_request = llm_request
479+ )
480+
481+ assert llm_request .config .tools is not None
482+ assert len (llm_request .config .tools ) == 1
483+ retrieval_tool = llm_request .config .tools [0 ]
484+ assert retrieval_tool .retrieval is not None
485+ assert retrieval_tool .retrieval .vertex_ai_search is not None
486+ # Verify the filter was dynamically set
487+ assert (
488+ retrieval_tool .retrieval .vertex_ai_search .filter
489+ == "user_id = 'test_user_123'"
490+ )
491+
492+ @pytest .mark .asyncio
493+ async def test_subclass_with_dynamic_max_results (self ):
494+ """Test subclassing to provide dynamic max_results based on context."""
495+
496+ class DynamicMaxResultsSearchTool (VertexAiSearchTool ):
497+ """Custom search tool with dynamic max_results."""
498+
499+ def _build_vertex_ai_search_config (self , ctx ):
500+ # Use a larger max_results for premium users
501+ is_premium = ctx .state .get ('is_premium' , False )
502+ dynamic_max_results = 20 if is_premium else 5
503+ return types .VertexAISearch (
504+ datastore = self .data_store_id ,
505+ engine = self .search_engine_id ,
506+ filter = self .filter ,
507+ max_results = dynamic_max_results ,
508+ )
509+
510+ tool = DynamicMaxResultsSearchTool (
511+ data_store_id = 'test_data_store' , max_results = 10
512+ )
513+ tool_context = await _create_tool_context ()
514+ tool_context .state ['is_premium' ] = True
515+
516+ llm_request = LlmRequest (
517+ model = 'gemini-2.5-pro' , config = types .GenerateContentConfig ()
518+ )
519+
520+ await tool .process_llm_request (
521+ tool_context = tool_context , llm_request = llm_request
522+ )
523+
524+ retrieval_tool = llm_request .config .tools [0 ]
525+ # Verify max_results was dynamically set to premium value
526+ assert retrieval_tool .retrieval .vertex_ai_search .max_results == 20
527+
528+ @pytest .mark .asyncio
529+ async def test_subclass_receives_readonly_context (self ):
530+ """Test that subclass receives the context correctly."""
531+ received_contexts = []
532+
533+ class ContextCapturingSearchTool (VertexAiSearchTool ):
534+ """Custom search tool that captures the context."""
535+
536+ def _build_vertex_ai_search_config (self , ctx ):
537+ received_contexts .append (ctx )
538+ return types .VertexAISearch (
539+ datastore = self .data_store_id ,
540+ engine = self .search_engine_id ,
541+ filter = self .filter ,
542+ max_results = self .max_results ,
543+ )
544+
545+ tool = ContextCapturingSearchTool (data_store_id = 'test_data_store' )
546+ tool_context = await _create_tool_context ()
547+
548+ llm_request = LlmRequest (
549+ model = 'gemini-2.5-pro' , config = types .GenerateContentConfig ()
550+ )
551+
552+ await tool .process_llm_request (
553+ tool_context = tool_context , llm_request = llm_request
554+ )
555+
556+ # Verify the context was passed to _build_vertex_ai_search_config
557+ assert len (received_contexts ) == 1
558+ assert received_contexts [0 ] is tool_context
0 commit comments