From 7f964d5625bea84fada767efc32661db34473a80 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 11 Jun 2025 12:46:04 -0700 Subject: [PATCH 01/24] feat: Export global quota configs in preview sdk PiperOrigin-RevId: 770276336 --- .../vertex_rag/test_rag_constants_preview.py | 19 ++++++++++++++++ .../unit/vertex_rag/test_rag_data_preview.py | 22 +++++++++++++++++++ vertexai/preview/rag/rag_data.py | 18 +++++++++++++++ vertexai/preview/rag/utils/_gapic_utils.py | 14 +++++++++++- vertexai/preview/rag/utils/resources.py | 9 ++++++++ 5 files changed, 81 insertions(+), 1 deletion(-) diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index cd4fb6a107..08a8896f79 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -485,10 +485,12 @@ TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig( rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, rebuild_ann_index=False, + max_embedding_requests_per_min=1000, ) TEST_IMPORT_FILES_CONFIG_GCS_REBUILD_ANN_INDEX = ImportRagFilesConfig( rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, rebuild_ann_index=True, + max_embedding_requests_per_min=1000, ) TEST_IMPORT_FILES_CONFIG_GCS_REBUILD_ANN_INDEX.gcs_source.uris = [TEST_GCS_PATH] TEST_IMPORT_FILES_CONFIG_GCS_REBUILD_ANN_INDEX.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = ( @@ -517,6 +519,7 @@ TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig( rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, rebuild_ann_index=False, + max_embedding_requests_per_min=1000, ) TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [ GoogleDriveSource.ResourceId( @@ -530,6 +533,7 @@ TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig( rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, rebuild_ann_index=False, + max_embedding_requests_per_min=1000, ) TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [ GoogleDriveSource.ResourceId( @@ -589,6 +593,7 @@ ) ), rebuild_ann_index=False, + max_embedding_requests_per_min=1000, ) TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800 @@ -603,6 +608,14 @@ import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FILE, ) +TEST_IMPORT_REQUEST_DRIVE_FILE_GLOBAL_QUOTA_CONTROL = ImportRagFilesRequest( + parent=TEST_RAG_CORPUS_RESOURCE_NAME, + import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FILE, +) +TEST_IMPORT_REQUEST_DRIVE_FILE_GLOBAL_QUOTA_CONTROL.import_rag_files_config.global_max_embedding_requests_per_min = ( + 8000 +) + TEST_IMPORT_RESPONSE = ImportRagFilesResponse(imported_rag_files_count=2) TEST_GAPIC_RAG_FILE = GapicRagFile( @@ -649,6 +662,7 @@ rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG, rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, rebuild_ann_index=False, + max_embedding_requests_per_min=1000, ) TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [ GapicSlackSource.SlackChannels( @@ -703,6 +717,7 @@ rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG, rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, rebuild_ann_index=False, + max_embedding_requests_per_min=1000, ) TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [ GapicJiraSource.JiraQueries( @@ -736,6 +751,7 @@ TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE = ImportRagFilesConfig( 
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG, rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, + max_embedding_requests_per_min=1000, share_point_sources=GapicSharePointSources( share_point_sources=[ GapicSharePointSources.SharePointSource( @@ -813,6 +829,7 @@ TEST_LAYOUT_PARSER_WITH_PROCESSOR_PATH_CONFIG = LayoutParserConfig( processor_name="projects/test-project/locations/us/processors/abc123", max_parsing_requests_per_min=100, + global_max_parsing_requests_per_min=1000, ) TEST_LAYOUT_PARSER_WITH_PROCESSOR_VERSION_PATH_CONFIG = LayoutParserConfig( @@ -885,6 +902,7 @@ TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig( rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG, + max_embedding_requests_per_min=1000, share_point_sources=GapicSharePointSources( share_point_sources=[ GapicSharePointSources.SharePointSource( @@ -914,6 +932,7 @@ layout_parser=RagFileParsingConfig.LayoutParser( processor_name="projects/test-project/locations/us/processors/abc123", max_parsing_requests_per_min=100, + global_max_parsing_requests_per_min=1000, ) ) ) diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index c1661adff5..cdb4c5823e 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -659,6 +659,14 @@ def import_files_request_eq(returned_request, expected_request): returned_request.import_rag_files_config.rebuild_ann_index == expected_request.import_rag_files_config.rebuild_ann_index ) + assert ( + returned_request.import_rag_files_config.max_embedding_requests_per_min + == expected_request.import_rag_files_config.max_embedding_requests_per_min + ) + assert ( + returned_request.import_rag_files_config.global_max_embedding_requests_per_min + == expected_request.import_rag_files_config.global_max_embedding_requests_per_min + ) def rag_engine_config_eq(returned_config, expected_config): @@ -1349,6 +1357,20 @@ def test_prepare_import_files_request_drive_files(self): request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FILE ) + def test_prepare_import_files_request_drive_files_with_global_quota_control(self): + paths = [test_rag_constants_preview.TEST_DRIVE_FILE] + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=paths, + transformation_config=create_transformation_config(), + max_embedding_requests_per_min=800, + global_max_embedding_requests_per_min=8000, + ) + import_files_request_eq( + request, + test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FILE_GLOBAL_QUOTA_CONTROL, + ) + def test_prepare_import_files_request_invalid_drive_path(self): with pytest.raises(ValueError) as e: paths = ["https://drive.google.com/bslalsdfk/whichever_file/456"] diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index b074f52dd1..7ffe6595a7 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -496,6 +496,7 @@ def import_files( transformation_config: Optional[TransformationConfig] = None, timeout: int = 600, max_embedding_requests_per_min: int = 1000, + global_max_embedding_requests_per_min: Optional[int] = None, use_advanced_pdf_parsing: Optional[bool] = False, partial_failures_sink: Optional[str] = None, layout_parser: Optional[LayoutParserConfig] = None, @@ -605,6 +606,13 @@ def import_files( page on the project to set an appropriate value here. 
If unspecified, a default value of 1,000 QPM would be used. + global_max_embedding_requests_per_min: + Optional. The max number of queries per minute that the indexing + pipeline job is allowed to make to the embedding model specified in + the project. Please follow the quota usage guideline of the embedding + model you use to set the value properly. If this value is not specified, + max_embedding_requests_per_min will be used by indexing pipeline job + as the global limit and this means parallel import jobs are not allowed. timeout: Default is 600 seconds. use_advanced_pdf_parsing: Whether to use advanced PDF parsing on uploaded files. This field is deprecated. @@ -663,6 +671,7 @@ def import_files( chunk_overlap=chunk_overlap, transformation_config=transformation_config, max_embedding_requests_per_min=max_embedding_requests_per_min, + global_max_embedding_requests_per_min=global_max_embedding_requests_per_min, use_advanced_pdf_parsing=use_advanced_pdf_parsing, partial_failures_sink=partial_failures_sink, layout_parser=layout_parser, @@ -686,6 +695,7 @@ async def import_files_async( chunk_overlap: int = 200, transformation_config: Optional[TransformationConfig] = None, max_embedding_requests_per_min: int = 1000, + global_max_embedding_requests_per_min: Optional[int] = None, use_advanced_pdf_parsing: Optional[bool] = False, partial_failures_sink: Optional[str] = None, layout_parser: Optional[LayoutParserConfig] = None, @@ -796,6 +806,13 @@ async def import_files_async( page on the project to set an appropriate value here. If unspecified, a default value of 1,000 QPM would be used. + global_max_embedding_requests_per_min: + Optional. The max number of queries per minute that the indexing + pipeline job is allowed to make to the embedding model specified in + the project. Please follow the quota usage guideline of the embedding + model you use to set the value properly. If this value is not specified, + max_embedding_requests_per_min will be used by indexing pipeline job + as the global limit and this means parallel import jobs are not allowed. use_advanced_pdf_parsing: Whether to use advanced PDF parsing on uploaded files. 
partial_failures_sink: Either a GCS path to store partial failures or a @@ -852,6 +869,7 @@ async def import_files_async( chunk_overlap=chunk_overlap, transformation_config=transformation_config, max_embedding_requests_per_min=max_embedding_requests_per_min, + global_max_embedding_requests_per_min=global_max_embedding_requests_per_min, use_advanced_pdf_parsing=use_advanced_pdf_parsing, partial_failures_sink=partial_failures_sink, layout_parser=layout_parser, diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index a36799ddc6..5e99e5a59d 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -537,6 +537,7 @@ def prepare_import_files_request( chunk_overlap: int = 200, transformation_config: Optional[TransformationConfig] = None, max_embedding_requests_per_min: int = 1000, + global_max_embedding_requests_per_min: Optional[int] = None, use_advanced_pdf_parsing: bool = False, partial_failures_sink: Optional[str] = None, layout_parser: Optional[LayoutParserConfig] = None, @@ -569,8 +570,15 @@ def prepare_import_files_request( ) rag_file_parsing_config.layout_parser = RagFileParsingConfig.LayoutParser( processor_name=layout_parser.processor_name, - max_parsing_requests_per_min=layout_parser.max_parsing_requests_per_min, ) + if layout_parser.max_parsing_requests_per_min is not None: + rag_file_parsing_config.layout_parser.max_parsing_requests_per_min = ( + layout_parser.max_parsing_requests_per_min + ) + if layout_parser.global_max_parsing_requests_per_min is not None: + rag_file_parsing_config.layout_parser.global_max_parsing_requests_per_min = ( + layout_parser.global_max_parsing_requests_per_min + ) if llm_parser is not None: rag_file_parsing_config.llm_parser = RagFileParsingConfig.LlmParser( model_name=llm_parser.model_name @@ -609,6 +617,10 @@ def prepare_import_files_request( rebuild_ann_index=rebuild_ann_index, ) + if global_max_embedding_requests_per_min is not None: + import_rag_files_config.global_max_embedding_requests_per_min = ( + global_max_embedding_requests_per_min + ) if source is not None: gapic_source = convert_source_for_rag_import(source) if isinstance(gapic_source, GapicSlackSource): diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index f3c1cbe350..bf8e8bffff 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -515,10 +515,19 @@ class LayoutParserConfig: https://cloud.google.com/document-ai/quotas and the Quota page for your project to set an appropriate value here. If unspecified, a default value of 120 QPM will be used. + global_max_parsing_requests_per_min (int): + The maximum number of requests the job is allowed to make to + the Document AI processor per minute in this project. + Consult https://cloud.google.com/document-ai/quotas and the + Quota page for your project to set an appropriate value + here. If this value is not specified, + max_parsing_requests_per_min will be used by indexing + pipeline as the global limit. """ processor_name: str max_parsing_requests_per_min: Optional[int] = None + global_max_parsing_requests_per_min: Optional[int] = None @dataclasses.dataclass From 84895b6c6cd8d898d8472f0a1ace12a8b420717b Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 11 Jun 2025 12:56:14 -0700 Subject: [PATCH 02/24] fix: Use none check to avoid 30s delay in agent run. 
PiperOrigin-RevId: 770280829 --- vertexai/preview/reasoning_engines/templates/adk.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 34fbcf8757..06a7ff01ef 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -627,15 +627,17 @@ def _asyncio_thread_main(): asyncio.run(_invoke_agent_async()) except RuntimeError as e: event_queue.put(e) + finally: + # Use None as a sentinel to stop the main thread. + event_queue.put(None) thread = threading.Thread(target=_asyncio_thread_main) thread.start() try: while True: - try: - event = event_queue.get(timeout=30) - except queue.Empty: + event = event_queue.get() + if event is None: break if isinstance(event, RuntimeError): raise event From 267b53d4a7db87cdf70181f76adb5c6980a2136a Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 11 Jun 2025 14:57:36 -0700 Subject: [PATCH 03/24] feat: Add PSC interface config support for Custom Training Jobs PiperOrigin-RevId: 770333532 --- google/cloud/aiplatform/jobs.py | 24 ++++ google/cloud/aiplatform/training_jobs.py | 63 +++++++++ tests/unit/aiplatform/constants.py | 1 + tests/unit/aiplatform/test_custom_job.py | 148 ++++++++++++++++++++ tests/unit/aiplatform/test_training_jobs.py | 124 ++++++++++++++++ 5 files changed, 360 insertions(+) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 9651200643..73f5997b6d 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -45,6 +45,7 @@ model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat, job_state_v1beta1 as gca_job_state_v1beta1, model_monitoring_v1beta1 as gca_model_monitoring_v1beta1, + service_networking as gca_service_networking, ) # TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA from google.cloud.aiplatform.constants import base as constants @@ -2236,6 +2237,9 @@ def run( persistent_resource_id: Optional[str] = None, scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> None: """Run this configured CustomJob. @@ -2310,6 +2314,9 @@ def run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 1 day. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. """ network = network or initializer.global_config.network service_account = service_account or initializer.global_config.service_account @@ -2329,6 +2336,7 @@ def run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) @base.optional_sync() @@ -2348,6 +2356,9 @@ def _run( persistent_resource_id: Optional[str] = None, scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> None: """Helper method to ensure network synchronization and to run the configured CustomJob. 
@@ -2420,6 +2431,9 @@ def _run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 1 day. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. """ self.submit( service_account=service_account, @@ -2435,6 +2449,7 @@ def _run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) self._block_until_complete() @@ -2455,6 +2470,9 @@ def submit( persistent_resource_id: Optional[str] = None, scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> None: """Submit the configured CustomJob. @@ -2524,6 +2542,9 @@ def submit( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 1 day. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. Raises: ValueError: @@ -2546,6 +2567,9 @@ def submit( if network: self._gca_resource.job_spec.network = network + if psc_interface_config: + self._gca_resource.job_spec.psc_interface_config = psc_interface_config + if ( timeout or restart_job_on_worker_restart diff --git a/google/cloud/aiplatform/training_jobs.py b/google/cloud/aiplatform/training_jobs.py index 056a3fe5dc..9a761b2935 100644 --- a/google/cloud/aiplatform/training_jobs.py +++ b/google/cloud/aiplatform/training_jobs.py @@ -45,6 +45,7 @@ training_pipeline as gca_training_pipeline, study as gca_study_compat, custom_job as gca_custom_job_compat, + service_networking as gca_service_networking, ) from google.cloud.aiplatform.utils import _timestamped_gcs_dir @@ -1553,6 +1554,9 @@ def _prepare_training_task_inputs_and_output_dir( persistent_resource_id: Optional[str] = None, scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Tuple[Dict, str]: """Prepares training task inputs and output directory for custom job. @@ -1617,6 +1621,8 @@ def _prepare_training_task_inputs_and_output_dir( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. The PSC interface config for the job. Returns: Training task inputs and Output directory for custom job. 
""" @@ -1645,6 +1651,8 @@ def _prepare_training_task_inputs_and_output_dir( training_task_inputs["enable_dashboard_access"] = enable_dashboard_access if persistent_resource_id: training_task_inputs["persistent_resource_id"] = persistent_resource_id + if psc_interface_config: + training_task_inputs["psc_interface_config"] = psc_interface_config if ( timeout @@ -3055,6 +3063,9 @@ def run( reservation_affinity_key: Optional[str] = None, reservation_affinity_values: Optional[List[str]] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Runs the custom training job. @@ -3433,6 +3444,9 @@ def run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. Returns: The trained Vertex AI model resource or None if the training @@ -3504,6 +3518,7 @@ def run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) def submit( @@ -3564,6 +3579,9 @@ def submit( reservation_affinity_key: Optional[str] = None, reservation_affinity_values: Optional[List[str]] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Submits the custom training job without blocking until completion. @@ -3887,6 +3905,9 @@ def submit( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -3958,6 +3979,7 @@ def submit( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -4007,6 +4029,9 @@ def _run( persistent_resource_id: Optional[str] = None, scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Packages local script and launches training_job. @@ -4209,6 +4234,8 @@ def _run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. The PSC interface config for the job. 
Returns: model: The trained Vertex AI Model resource or None if training did not @@ -4265,6 +4292,7 @@ def _run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) model = self._run_job( @@ -4596,6 +4624,9 @@ def run( reservation_affinity_key: Optional[str] = None, reservation_affinity_values: Optional[List[str]] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Runs the custom training job. @@ -4912,6 +4943,9 @@ def run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -4982,6 +5016,7 @@ def run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) def submit( @@ -5042,6 +5077,9 @@ def submit( reservation_affinity_key: Optional[str] = None, reservation_affinity_values: Optional[List[str]] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Submits the custom training job without blocking until completion. @@ -5358,6 +5396,9 @@ def submit( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -5428,6 +5469,7 @@ def submit( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -5476,6 +5518,9 @@ def _run( persistent_resource_id: Optional[str] = None, scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Packages local script and launches training_job. Args: @@ -5674,6 +5719,9 @@ def _run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. 
Returns: model: The trained Vertex AI Model resource or None if training did not @@ -5724,6 +5772,7 @@ def _run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) model = self._run_job( @@ -7755,6 +7804,9 @@ def run( reservation_affinity_key: Optional[str] = None, reservation_affinity_values: Optional[List[str]] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Runs the custom training job. @@ -8072,6 +8124,9 @@ def run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. Returns: model: The trained Vertex AI Model resource or None if training did not @@ -8137,6 +8192,7 @@ def run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) @base.optional_sync(construct_object_on_arg="managed_model") @@ -8184,6 +8240,9 @@ def _run( persistent_resource_id: Optional[str] = None, scheduling_strategy: Optional[gca_custom_job_compat.Scheduling.Strategy] = None, max_wait_duration: Optional[int] = None, + psc_interface_config: Optional[ + gca_service_networking.PscInterfaceConfig + ] = None, ) -> Optional[models.Model]: """Packages local script and launches training_job. @@ -8367,6 +8426,9 @@ def _run( This is the maximum duration that a job will wait for the requested resources to be provisioned in seconds. If set to 0, the job will wait indefinitely. The default is 30 minutes. + psc_interface_config (gca_service_networking.PscInterfaceConfig): + Optional. Configuration for Private Service Connect interface + used for training. 
Returns: model: The trained Vertex AI Model resource or None if training did not @@ -8417,6 +8479,7 @@ def _run( persistent_resource_id=persistent_resource_id, scheduling_strategy=scheduling_strategy, max_wait_duration=max_wait_duration, + psc_interface_config=psc_interface_config, ) model = self._run_job( diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py index 633659b449..1074b3f94f 100644 --- a/tests/unit/aiplatform/constants.py +++ b/tests/unit/aiplatform/constants.py @@ -222,6 +222,7 @@ class TrainingJobConstants: ) _TEST_DEFAULT_ENCRYPTION_KEY_NAME = "key_default" _TEST_SPOT_STRATEGY = custom_job.Scheduling.Strategy.SPOT + _TEST_PSC_INTERFACE_CONFIG = {"network_attachment": "network_attachment_value"} def create_tpu_job_proto(tpu_version): worker_pool_spec = ( diff --git a/tests/unit/aiplatform/test_custom_job.py b/tests/unit/aiplatform/test_custom_job.py index 6bd1bca9e7..75db082ea3 100644 --- a/tests/unit/aiplatform/test_custom_job.py +++ b/tests/unit/aiplatform/test_custom_job.py @@ -65,6 +65,9 @@ ) _TEST_PREBUILT_CONTAINER_IMAGE = "gcr.io/cloud-aiplatform/container:image" _TEST_SPOT_STRATEGY = test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY +_TEST_PSC_INTERFACE_CONFIG = ( + test_constants.TrainingJobConstants._TEST_PSC_INTERFACE_CONFIG +) _TEST_RUN_ARGS = test_constants.TrainingJobConstants._TEST_RUN_ARGS _TEST_EXPERIMENT = "test-experiment" @@ -248,6 +251,12 @@ def _get_custom_job_proto_with_spot_strategy(state=None, name=None, error=None): return custom_job_proto +def _get_custom_job_proto_with_psc_interface_config(state=None, name=None, error=None): + custom_job_proto = _get_custom_job_proto(state=state, name=name, error=error) + custom_job_proto.job_spec.psc_interface_config = _TEST_PSC_INTERFACE_CONFIG + return custom_job_proto + + @pytest.fixture def mock_builtin_open(): with patch("builtins.open", mock_open(read_data="data")) as mock_file: @@ -462,6 +471,29 @@ def get_custom_job_mock_with_spot_strategy(): yield get_custom_job_mock +@pytest.fixture +def get_custom_job_mock_with_psc_interface_config(): + """Fixture for mocking get_custom_job with psc interface config.""" + with patch.object( + job_service_client.JobServiceClient, "get_custom_job" + ) as get_custom_job_mock: + get_custom_job_mock.side_effect = [ + _get_custom_job_proto_with_psc_interface_config( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ), + _get_custom_job_proto_with_psc_interface_config( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_RUNNING, + ), + _get_custom_job_proto_with_psc_interface_config( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED, + ), + ] + yield get_custom_job_mock + + @pytest.fixture def create_custom_job_mock(): with mock.patch.object( @@ -523,6 +555,20 @@ def create_custom_job_mock_with_spot_strategy(): yield create_custom_job_mock +@pytest.fixture +def create_custom_job_mock_with_psc_interface_config(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_custom_job" + ) as create_custom_job_mock: + create_custom_job_mock.return_value = ( + _get_custom_job_proto_with_psc_interface_config( + name=_TEST_CUSTOM_JOB_NAME, + state=gca_job_state_compat.JobState.JOB_STATE_PENDING, + ) + ) + yield create_custom_job_mock + + _EXPERIMENT_MOCK = copy.deepcopy(_EXPERIMENT_MOCK) _EXPERIMENT_MOCK.metadata[ constants._BACKING_TENSORBOARD_RESOURCE_KEY @@ -1667,3 +1713,105 @@ def test_create_custom_job_with_spot_strategy( 
assert ( job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED ) + + def test_create_custom_job_with_psc_interface_config( + self, + create_custom_job_mock_with_psc_interface_config, + get_custom_job_mock_with_psc_interface_config, + ): + """Tests creating a custom job with psc interface config.""" + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = aiplatform.CustomJob( + display_name=_TEST_DISPLAY_NAME, + worker_pool_specs=_TEST_WORKER_POOL_SPEC, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + labels=_TEST_LABELS, + ) + + job.run( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + max_wait_duration=_TEST_MAX_WAIT_DURATION, + psc_interface_config=_TEST_PSC_INTERFACE_CONFIG, + ) + + job.wait_for_resource_creation() + + job.wait() + + assert job.resource_name == _TEST_CUSTOM_JOB_NAME + + expected_custom_job = _get_custom_job_proto_with_psc_interface_config() + + create_custom_job_mock_with_psc_interface_config.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + assert job.job_spec == expected_custom_job.job_spec + assert ( + job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_SUCCEEDED + ) + + def test_submit_custom_job_with_psc_interface_config( + self, + create_custom_job_mock_with_psc_interface_config, + get_custom_job_mock_with_psc_interface_config, + ): + """Tests submitting a custom job with psc interface config.""" + + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + staging_bucket=_TEST_STAGING_BUCKET, + encryption_spec_key_name=_TEST_DEFAULT_ENCRYPTION_KEY_NAME, + ) + + job = aiplatform.CustomJob( + display_name=_TEST_DISPLAY_NAME, + worker_pool_specs=_TEST_WORKER_POOL_SPEC, + base_output_dir=_TEST_BASE_OUTPUT_DIR, + labels=_TEST_LABELS, + ) + + job.submit( + service_account=_TEST_SERVICE_ACCOUNT, + network=_TEST_NETWORK, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + max_wait_duration=_TEST_MAX_WAIT_DURATION, + psc_interface_config=_TEST_PSC_INTERFACE_CONFIG, + ) + + job.wait_for_resource_creation() + + assert job.resource_name == _TEST_CUSTOM_JOB_NAME + + job.wait() + + expected_custom_job = _get_custom_job_proto_with_psc_interface_config() + + create_custom_job_mock_with_psc_interface_config.assert_called_once_with( + parent=_TEST_PARENT, + custom_job=expected_custom_job, + timeout=None, + ) + + assert job.job_spec == expected_custom_job.job_spec + assert ( + job._gca_resource.state == gca_job_state_compat.JobState.JOB_STATE_PENDING + ) diff --git a/tests/unit/aiplatform/test_training_jobs.py b/tests/unit/aiplatform/test_training_jobs.py index 08c258ac37..0b98dbdf22 100644 --- a/tests/unit/aiplatform/test_training_jobs.py +++ b/tests/unit/aiplatform/test_training_jobs.py @@ -250,6 +250,9 @@ test_constants.PersistentResourceConstants._TEST_PERSISTENT_RESOURCE_ID ) _TEST_SPOT_STRATEGY = test_constants.TrainingJobConstants._TEST_SPOT_STRATEGY +_TEST_PSC_INTERFACE_CONFIG = ( + test_constants.TrainingJobConstants._TEST_PSC_INTERFACE_CONFIG +) _TEST_BASE_CUSTOM_JOB_PROTO = gca_custom_job.CustomJob( job_spec=gca_custom_job.CustomJobSpec(), @@ 
-319,6 +322,17 @@ def _get_custom_job_proto_with_spot_strategy(state=None, name=None, version="v1" return custom_job_proto +def _get_custom_job_proto_with_psc_interface_config( + state=None, name=None, version="v1" +): + custom_job_proto = copy.deepcopy(_TEST_BASE_CUSTOM_JOB_PROTO) + custom_job_proto.name = name + custom_job_proto.state = state + + custom_job_proto.job_spec.psc_interface_config = _TEST_PSC_INTERFACE_CONFIG + return custom_job_proto + + def local_copy_method(path): shutil.copy(path, ".") return pathlib.Path(path).name @@ -844,6 +858,21 @@ def make_training_pipeline_with_spot_strategy(state): return training_pipeline +def make_training_pipeline_with_psc_interface_config(state): + training_pipeline = gca_training_pipeline.TrainingPipeline( + name=_TEST_PIPELINE_RESOURCE_NAME, + state=state, + training_task_inputs={ + "psc_interface_config": _TEST_PSC_INTERFACE_CONFIG, + }, + ) + if state == gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING: + training_pipeline.training_task_metadata = { + "backingCustomJob": _TEST_CUSTOM_JOB_RESOURCE_NAME + } + return training_pipeline + + @pytest.fixture def mock_pipeline_service_get(make_call=make_training_pipeline): with mock.patch.object( @@ -1015,6 +1044,35 @@ def mock_pipeline_service_get_with_spot_strategy(): yield mock_get_training_pipeline +@pytest.fixture +def mock_pipeline_service_get_with_psc_interface_config(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "get_training_pipeline" + ) as mock_get_training_pipeline: + mock_get_training_pipeline.side_effect = [ + make_training_pipeline_with_psc_interface_config( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ), + make_training_pipeline_with_psc_interface_config( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_psc_interface_config( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_psc_interface_config( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING, + ), + make_training_pipeline_with_psc_interface_config( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + make_training_pipeline_with_psc_interface_config( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED, + ), + ] + + yield mock_get_training_pipeline + + @pytest.fixture def mock_pipeline_service_cancel(): with mock.patch.object( @@ -1102,6 +1160,19 @@ def mock_pipeline_service_create_with_spot_strategy(): yield mock_create_training_pipeline +@pytest.fixture +def mock_pipeline_service_create_with_psc_interface_config(): + with mock.patch.object( + pipeline_service_client.PipelineServiceClient, "create_training_pipeline" + ) as mock_create_training_pipeline: + mock_create_training_pipeline.return_value = ( + make_training_pipeline_with_psc_interface_config( + state=gca_pipeline_state.PipelineState.PIPELINE_STATE_PENDING, + ) + ) + yield mock_create_training_pipeline + + @pytest.fixture def mock_pipeline_service_get_with_no_model_to_upload(): with mock.patch.object( @@ -2523,6 +2594,59 @@ def test_run_call_pipeline_service_create_with_spot_strategy(self, sync): == _TEST_SPOT_STRATEGY ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) + @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) + @pytest.mark.usefixtures( + "mock_pipeline_service_create_with_psc_interface_config", + "mock_pipeline_service_get_with_psc_interface_config", + "mock_python_package_to_gcs", + ) + @pytest.mark.parametrize("sync", [True, False]) + 
def test_run_call_pipeline_service_create_with_psc_interface_config(self, sync): + + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = training_jobs.CustomTrainingJob( + display_name=_TEST_DISPLAY_NAME, + script_path=_TEST_LOCAL_SCRIPT_FILE_NAME, + container_uri=_TEST_TRAINING_CONTAINER_IMAGE, + ) + + job.run( + base_output_dir=_TEST_BASE_OUTPUT_DIR, + args=_TEST_RUN_ARGS, + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + timeout=_TEST_TIMEOUT, + restart_job_on_worker_restart=_TEST_RESTART_JOB_ON_WORKER_RESTART, + sync=sync, + create_request_timeout=None, + disable_retries=_TEST_DISABLE_RETRIES, + psc_interface_config=_TEST_PSC_INTERFACE_CONFIG, + ) + + if not sync: + job.wait() + + assert job._gca_resource == make_training_pipeline_with_psc_interface_config( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + + assert ( + job._gca_resource.state + == gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ) + + assert ( + job._gca_resource.training_task_inputs["psc_interface_config"] + == _TEST_PSC_INTERFACE_CONFIG + ) + @mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1) @mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1) @pytest.mark.usefixtures( From a1f420582908bc3d9a3201d36bf8d075758d4644 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 11 Jun 2025 15:21:12 -0700 Subject: [PATCH 04/24] feat: Enable asia-south2 PiperOrigin-RevId: 770342332 --- google/cloud/aiplatform/constants/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/aiplatform/constants/base.py b/google/cloud/aiplatform/constants/base.py index d3a777da73..fa692b0c63 100644 --- a/google/cloud/aiplatform/constants/base.py +++ b/google/cloud/aiplatform/constants/base.py @@ -28,6 +28,7 @@ "asia-northeast2", "asia-northeast3", "asia-south1", + "asia-south2", "asia-southeast1", "asia-southeast2", "australia-southeast1", From fe474aed162667e6e78870603f87d4bddf3f8b4f Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Thu, 12 Jun 2025 10:25:27 -0700 Subject: [PATCH 05/24] chore: refactor run_inference and bug fixes. PiperOrigin-RevId: 770710817 --- tests/unit/vertexai/genai/test_evals.py | 312 ++++++++++------- vertexai/_genai/_evals_common.py | 439 +++++++++++------------- 2 files changed, 392 insertions(+), 359 deletions(-) diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 15216334e7..a31aa18ddf 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -# pylint: disable=protected-access,bad-continuation,missing-function-docstring - import importlib import json import os @@ -558,14 +556,42 @@ def test_inference_with_row_level_config_overrides( ): mock_df = pd.DataFrame( { + "id": [1, 2, 3], "request": [ - json.dumps( - {"prompt": "req 1", "generation_config": {"top_p": 0.5}} - ), - json.dumps( - {"prompt": "req 2", "generation_config": {"top_p": 0.9}} - ), - ] + { + "contents": [ + { + "parts": [{"text": "Placeholder prompt 1"}], + "role": "user", + } + ] + }, + { + "contents": [ + { + "parts": [{"text": "Placeholder prompt 2.1"}], + "role": "user", + }, + { + "parts": [{"text": "Placeholder model response 2.1"}], + "role": "model", + }, + { + "parts": [{"text": "Placeholder prompt 2.2"}], + "role": "user", + }, + ], + "generation_config": {"temperature": 0.7, "top_k": 5}, + }, + { + "contents": [ + { + "parts": [{"text": "Placeholder prompt 3"}], + "role": "user", + } + ], + }, + ], } ) mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict( @@ -576,7 +602,7 @@ def test_inference_with_row_level_config_overrides( candidates=[ genai_types.Candidate( content=genai_types.Content( - parts=[genai_types.Part(text="resp 1")] + parts=[genai_types.Part(text="Placeholder response 1")] ), finish_reason=genai_types.FinishReason.STOP, ) @@ -587,7 +613,18 @@ def test_inference_with_row_level_config_overrides( candidates=[ genai_types.Candidate( content=genai_types.Content( - parts=[genai_types.Part(text="resp 2")] + parts=[genai_types.Part(text="Placeholder response 2")] + ), + finish_reason=genai_types.FinishReason.STOP, + ) + ], + prompt_feedback=None, + ), + genai_types.GenerateContentResponse( + candidates=[ + genai_types.Candidate( + content=genai_types.Content( + parts=[genai_types.Part(text="Placeholder response 3")] ), finish_reason=genai_types.FinishReason.STOP, ) @@ -607,40 +644,68 @@ def test_inference_with_row_level_config_overrides( [ mock.call( model="gemini-pro", - contents=json.dumps( - {"prompt": "req 1", "generation_config": {"top_p": 0.5}} - ), - config=genai_types.GenerateContentConfig(top_p=0.5), + contents=[ + {"parts": [{"text": "Placeholder prompt 1"}], "role": "user"} + ], + config=genai_types.GenerateContentConfig(), ), mock.call( model="gemini-pro", - contents=json.dumps( - {"prompt": "req 2", "generation_config": {"top_p": 0.9}} - ), - config=genai_types.GenerateContentConfig(top_p=0.9), + contents=[ + {"parts": [{"text": "Placeholder prompt 2.1"}], "role": "user"}, + { + "parts": [{"text": "Placeholder model response 2.1"}], + "role": "model", + }, + {"parts": [{"text": "Placeholder prompt 2.2"}], "role": "user"}, + ], + config=genai_types.GenerateContentConfig(temperature=0.7, top_k=5), + ), + mock.call( + model="gemini-pro", + contents=[ + {"parts": [{"text": "Placeholder prompt 3"}], "role": "user"} + ], + config=genai_types.GenerateContentConfig(), ), ], any_order=True, ) + request_obj_1 = { + "contents": [{"parts": [{"text": "Placeholder prompt 1"}], "role": "user"}] + } + request_obj_2 = { + "contents": [ + {"parts": [{"text": "Placeholder prompt 2.1"}], "role": "user"}, + { + "parts": [{"text": "Placeholder model response 2.1"}], + "role": "model", + }, + {"parts": [{"text": "Placeholder prompt 2.2"}], "role": "user"}, + ], + "generation_config": {"temperature": 0.7, "top_k": 5}, + } + request_obj_3 = { + "contents": [{"parts": [{"text": "Placeholder prompt 3"}], "role": "user"}], + } expected_df = pd.DataFrame( { - "request": [ - json.dumps( - {"prompt": "req 1", "generation_config": {"top_p": 
0.5}} - ), - json.dumps( - {"prompt": "req 2", "generation_config": {"top_p": 0.9}} - ), + "id": [1, 2, 3], + "request": [request_obj_1, request_obj_2, request_obj_3], + "response": [ + "Placeholder response 1", + "Placeholder response 2", + "Placeholder response 3", ], - "response": ["resp 1", "resp 2"], } ) pd.testing.assert_frame_equal( - inference_result.eval_dataset_df.sort_values(by="request").reset_index( + inference_result.eval_dataset_df.sort_values(by="id").reset_index( drop=True ), - expected_df.sort_values(by="request").reset_index(drop=True), + expected_df.sort_values(by="id").reset_index(drop=True), + check_dtype=False, ) @mock.patch.object(_evals_common, "Models") @@ -2483,9 +2548,8 @@ def my_custom_metric_fn(data: dict): assert summary_metric.mean_score == 0.5 mock_eval_dependencies["mock_evaluate_instances"].assert_not_called() - @mock.patch("vertexai._genai._evals_metric_handlers.LLMMetricHandler.process") def test_llm_metric_default_aggregation_mixed_results( - self, mock_llm_process, mock_api_client_fixture, mock_eval_dependencies + self, mock_api_client_fixture, mock_eval_dependencies ): dataset_df = pd.DataFrame( [ @@ -2501,43 +2565,42 @@ def test_llm_metric_default_aggregation_mixed_results( name="quality", prompt_template="Rate: {response}" ) - mock_llm_process.side_effect = [ - vertexai_genai_types.EvalCaseMetricResult( - metric_name="quality", score=0.8, explanation="Good" - ), - vertexai_genai_types.EvalCaseMetricResult( - metric_name="quality", score=0.6, explanation="Okay" - ), - vertexai_genai_types.EvalCaseMetricResult( - metric_name="quality", error_message="Processing failed" - ), - ] + with mock.patch( + "vertexai._genai._evals_metric_handlers.LLMMetricHandler.process" + ) as mock_llm_process: + mock_llm_process.side_effect = [ + vertexai_genai_types.EvalCaseMetricResult( + metric_name="quality", score=0.8, explanation="Good" + ), + vertexai_genai_types.EvalCaseMetricResult( + metric_name="quality", score=0.6, explanation="Okay" + ), + vertexai_genai_types.EvalCaseMetricResult( + metric_name="quality", error_message="Processing failed" + ), + ] - result = _evals_common._execute_evaluation( - api_client=mock_api_client_fixture, - dataset=input_dataset, - metrics=[llm_metric], - ) + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, + dataset=input_dataset, + metrics=[llm_metric], + ) + + assert mock_llm_process.call_count == 3 + assert len(result.summary_metrics) == 1 + summary = result.summary_metrics[0] + assert summary.metric_name == "quality" + assert summary.num_cases_total == 3 + assert summary.num_cases_valid == 2 + assert summary.num_cases_error == 1 + assert summary.mean_score == pytest.approx(0.7) + assert summary.stdev_score == pytest.approx(statistics.stdev([0.8, 0.6])) - assert mock_llm_process.call_count == 3 - assert len(result.summary_metrics) == 1 - summary = result.summary_metrics[0] - assert summary.metric_name == "quality" - assert summary.num_cases_total == 3 - assert summary.num_cases_valid == 2 - assert summary.num_cases_error == 1 - assert summary.mean_score == pytest.approx(0.7) - assert summary.stdev_score == pytest.approx(statistics.stdev([0.8, 0.6])) - - @mock.patch("vertexai._genai._evals_metric_handlers.LLMMetricHandler.process") def test_llm_metric_custom_aggregation_success( - self, mock_llm_process, mock_api_client_fixture, mock_eval_dependencies + self, mock_api_client_fixture, mock_eval_dependencies ): dataset_df = pd.DataFrame( - [ - {"prompt": "P1", "response": "R1"}, - {"prompt": "P2", 
"response": "R2"}, - ] + [{"prompt": "P1", "response": "R1"}, {"prompt": "P2", "response": "R2"}] ) input_dataset = vertexai_genai_types.EvaluationDataset( eval_dataset_df=dataset_df @@ -2556,32 +2619,34 @@ def custom_agg_fn(results: list[vertexai_genai_types.EvalCaseMetricResult]): aggregate_summary_fn=custom_agg_fn, ) - mock_llm_process.side_effect = [ - vertexai_genai_types.EvalCaseMetricResult( - metric_name="custom_quality", score=0.8 - ), - vertexai_genai_types.EvalCaseMetricResult( - metric_name="custom_quality", score=0.7 - ), - ] + with mock.patch( + "vertexai._genai._evals_metric_handlers.LLMMetricHandler.process" + ) as mock_llm_process: + mock_llm_process.side_effect = [ + vertexai_genai_types.EvalCaseMetricResult( + metric_name="custom_quality", score=0.8 + ), + vertexai_genai_types.EvalCaseMetricResult( + metric_name="custom_quality", score=0.7 + ), + ] + + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, + dataset=input_dataset, + metrics=[llm_metric], + ) + assert mock_llm_process.call_count == 2 + assert len(result.summary_metrics) == 1 + summary = result.summary_metrics[0] + assert summary.metric_name == "custom_quality" + assert summary.num_cases_total == 2 + assert summary.num_cases_valid == 2 + assert summary.mean_score == 0.75 + assert summary.model_dump(exclude_none=True)["my_custom_stat"] == 123 - result = _evals_common._execute_evaluation( - api_client=mock_api_client_fixture, - dataset=input_dataset, - metrics=[llm_metric], - ) - assert mock_llm_process.call_count == 2 - assert len(result.summary_metrics) == 1 - summary = result.summary_metrics[0] - assert summary.metric_name == "custom_quality" - assert summary.num_cases_total == 2 - assert summary.num_cases_valid == 2 - assert summary.mean_score == 0.75 - assert summary.model_dump(exclude_none=True)["my_custom_stat"] == 123 - - @mock.patch("vertexai._genai._evals_metric_handlers.LLMMetricHandler.process") def test_llm_metric_custom_aggregation_error_fallback( - self, mock_llm_process, mock_api_client_fixture, mock_eval_dependencies + self, mock_api_client_fixture, mock_eval_dependencies ): dataset_df = pd.DataFrame( [{"prompt": "P1", "response": "R1"}, {"prompt": "P2", "response": "R2"}] @@ -2600,31 +2665,33 @@ def custom_agg_fn_error( prompt_template="Rate: {response}", aggregate_summary_fn=custom_agg_fn_error, ) - mock_llm_process.side_effect = [ - vertexai_genai_types.EvalCaseMetricResult( - metric_name="error_fallback_quality", score=0.9 - ), - vertexai_genai_types.EvalCaseMetricResult( - metric_name="error_fallback_quality", score=0.5 - ), - ] - result = _evals_common._execute_evaluation( - api_client=mock_api_client_fixture, - dataset=input_dataset, - metrics=[llm_metric], - ) - assert mock_llm_process.call_count == 2 - summary = result.summary_metrics[0] - assert summary.metric_name == "error_fallback_quality" - assert summary.num_cases_total == 2 - assert summary.num_cases_valid == 2 - assert summary.num_cases_error == 0 - assert summary.mean_score == pytest.approx(0.7) - assert summary.stdev_score == pytest.approx(statistics.stdev([0.9, 0.5])) + with mock.patch( + "vertexai._genai._evals_metric_handlers.LLMMetricHandler.process" + ) as mock_llm_process: + mock_llm_process.side_effect = [ + vertexai_genai_types.EvalCaseMetricResult( + metric_name="error_fallback_quality", score=0.9 + ), + vertexai_genai_types.EvalCaseMetricResult( + metric_name="error_fallback_quality", score=0.5 + ), + ] + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, 
+ dataset=input_dataset, + metrics=[llm_metric], + ) + assert mock_llm_process.call_count == 2 + summary = result.summary_metrics[0] + assert summary.metric_name == "error_fallback_quality" + assert summary.num_cases_total == 2 + assert summary.num_cases_valid == 2 + assert summary.num_cases_error == 0 + assert summary.mean_score == pytest.approx(0.7) + assert summary.stdev_score == pytest.approx(statistics.stdev([0.9, 0.5])) - @mock.patch("vertexai._genai._evals_metric_handlers.LLMMetricHandler.process") def test_llm_metric_custom_aggregation_invalid_return_type_fallback( - self, mock_llm_process, mock_api_client_fixture, mock_eval_dependencies + self, mock_api_client_fixture, mock_eval_dependencies ): dataset_df = pd.DataFrame([{"prompt": "P1", "response": "R1"}]) input_dataset = vertexai_genai_types.EvaluationDataset( @@ -2641,17 +2708,20 @@ def custom_agg_fn_invalid_type( prompt_template="Rate: {response}", aggregate_summary_fn=custom_agg_fn_invalid_type, ) - mock_llm_process.return_value = vertexai_genai_types.EvalCaseMetricResult( - metric_name="invalid_type_fallback", score=0.8 - ) - result = _evals_common._execute_evaluation( - api_client=mock_api_client_fixture, - dataset=input_dataset, - metrics=[llm_metric], - ) - summary = result.summary_metrics[0] - assert summary.mean_score == 0.8 - assert summary.num_cases_valid == 1 + with mock.patch( + "vertexai._genai._evals_metric_handlers.LLMMetricHandler.process" + ) as mock_llm_process: + mock_llm_process.return_value = vertexai_genai_types.EvalCaseMetricResult( + metric_name="invalid_type_fallback", score=0.8 + ) + result = _evals_common._execute_evaluation( + api_client=mock_api_client_fixture, + dataset=input_dataset, + metrics=[llm_metric], + ) + summary = result.summary_metrics[0] + assert summary.mean_score == 0.8 + assert summary.num_cases_valid == 1 def test_execute_evaluation_lazy_loaded_prebuilt_metric_instance( self, mock_api_client_fixture, mock_eval_dependencies diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index 812411c523..a4e1564f87 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -128,29 +128,29 @@ def _generate_content_with_retry( def _build_generate_content_config( request_dict: dict[str, Any], - config: Optional[genai_types.GenerateContentConfig] = None, + global_config: Optional[genai_types.GenerateContentConfig] = None, ) -> genai_types.GenerateContentConfig: """Builds a GenerateContentConfig from the request dictionary or provided config.""" - if config: - # If a global config is provided, use it. - # User can still override parts of it if request_dict contains config fields. - merged_config_dict = config.model_dump(exclude_none=True) + if global_config: + # If a global config is provided, apply it as a base config. Parts of + # the global config can be overridden by providing configs in the + # request. 
+ merged_config_dict = global_config.model_dump(exclude_none=True) else: merged_config_dict = {} - # Overlay or set fields from request_dict - if "system_instruction" in request_dict: - merged_config_dict["system_instruction"] = request_dict["system_instruction"] - if "tools" in request_dict: - merged_config_dict["tools"] = request_dict["tools"] - if "tools_config" in request_dict: # Corrected variable name - merged_config_dict["tools_config"] = request_dict["tools_config"] - if "safety_settings" in request_dict: - merged_config_dict["safety_settings"] = request_dict["safety_settings"] + for key in [ + "system_instruction", + "tools", + "tools_config", + "safety_settings", + "labels", + ]: + if key in request_dict: + merged_config_dict[key] = request_dict[key] if "generation_config" in request_dict and isinstance( request_dict["generation_config"], dict ): - # Merge generation_config dict into the main config dict merged_config_dict.update(request_dict["generation_config"]) if "labels" in request_dict: merged_config_dict["labels"] = request_dict["labels"] @@ -158,24 +158,40 @@ def _build_generate_content_config( return genai_types.GenerateContentConfig(**merged_config_dict) -def _run_gemini_inference( +def _extract_contents_for_inference( + request_dict_or_raw_text: Any, +) -> Any: + """Extracts contents from a request dictionary or returns the raw text.""" + if not request_dict_or_raw_text: + raise ValueError("Prompt cannot be empty.") + if isinstance(request_dict_or_raw_text, dict): + contents_for_fn = request_dict_or_raw_text.get("contents", None) + if not contents_for_fn: + raise ValueError("Contents in the request cannot be empty.") + return contents_for_fn + else: + return request_dict_or_raw_text + + +def _execute_inference_concurrently( api_client: BaseApiClient, - model: str, + model_or_fn: Union[str, Callable[[Any], Any]], prompt_dataset: "pd.DataFrame", - config: Optional[genai_types.GenerateContentConfig] = None, + progress_desc: str, + gemini_config: Optional[genai_types.GenerateContentConfig] = None, + inference_fn: Optional[Callable[[Any, Any, Any, Any], Any]] = None, ) -> list[Union[genai_types.GenerateContentResponse, dict[str, Any]]]: - """Internal helper to run inference using Gemini model with concurrency.""" + """Internal helper to run inference with concurrency.""" logger.info( - "Generating responses for %d prompts using model: %s", + "Generating responses for %d prompts using model or function: %s", len(prompt_dataset), - model, + model_or_fn, ) responses: list[ Union[genai_types.GenerateContentResponse, dict[str, Any], None] ] = [None] * len(prompt_dataset) tasks = [] - # Determine the primary column for prompts primary_prompt_column = ( "request" if "request" in prompt_dataset.columns else "prompt" ) @@ -185,70 +201,36 @@ def _run_gemini_inference( f" Found: {prompt_dataset.columns.tolist()}" ) - with tqdm(total=len(prompt_dataset), desc="Gemini Inference") as pbar: + with tqdm(total=len(prompt_dataset), desc=progress_desc) as pbar: with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: for index, row in prompt_dataset.iterrows(): - request_dict = {} - contents_input = row[primary_prompt_column] - - if isinstance(contents_input, str): - try: - # Attempt to parse if it's a JSON string representing a complex request - parsed_json = json.loads(contents_input) - if isinstance(parsed_json, dict): - request_dict = parsed_json - contents = request_dict.get("contents", None) - if ( - contents is None - ): # If 'contents' not in JSON, assume whole 
string was the content - contents = contents_input - else: # Parsed to something other than dict (e.g. just a string literal in JSON) - contents = contents_input - except json.JSONDecodeError: - # It's a raw text prompt string - contents = contents_input - elif isinstance(contents_input, dict) or isinstance( - contents_input, list - ): - # Already in a structure that could be 'contents' or part of a larger request - contents = contents_input # Assume this is the 'contents' part - # To extract other configs, we'd need a clearer contract on how full requests are passed in rows - # For now, assume if it's dict/list, it's directly the 'contents' - else: - logger.error( - f"Unsupported type for prompt/request column at index {index}:" - f" {type(contents_input)}" + request_dict_or_raw_text = row[primary_prompt_column] + try: + contents = _extract_contents_for_inference(request_dict_or_raw_text) + except ValueError as e: + error_message = ( + f"Failed to extract contents for prompt at index {index}: {e}. " + "Skipping prompt." ) - responses[index] = { - "error": ( - f"Unsupported prompt/request type: {type(contents_input)}" - ) - } + logger.error(error_message) + responses[index] = {"error": error_message} pbar.update(1) continue - if contents is None: - logger.error( - f"Could not extract 'contents' for inference at index {index}." - f" Row data: {row[primary_prompt_column]}" + if isinstance(model_or_fn, str): + generation_content_config = _build_generate_content_config( + request_dict_or_raw_text, + gemini_config, ) - responses[index] = { - "error": "Could not extract 'contents' for inference." - } - pbar.update(1) - continue - - generation_content_config = _build_generate_content_config( - request_dict, - config, - ) - future = executor.submit( - _generate_content_with_retry, - api_client=api_client, - model=model, - contents=contents, - config=generation_content_config, - ) + future = executor.submit( + inference_fn, + api_client=api_client, + model=model_or_fn, + contents=contents, + config=generation_content_config, + ) + else: + future = executor.submit(model_or_fn, contents) future.add_done_callback(lambda _: pbar.update(1)) tasks.append((future, index)) @@ -257,142 +239,63 @@ def _run_gemini_inference( result = future.result() responses[index] = result except Exception as e: - logger.error("Error processing prompt at index %d: %s", index, e) - responses[index] = {"error": f"Gemini Inference task failed: {e}"} + logger.error( + "Error processing prompt at index %d: %s", + index, + e, + ) + responses[index] = {"error": f"Inference task failed: {e}"} return responses +def _run_gemini_inference( + api_client: BaseApiClient, + model: str, + prompt_dataset: "pd.DataFrame", + config: Optional[genai_types.GenerateContentConfig] = None, +) -> list[Union[genai_types.GenerateContentResponse, dict[str, Any]]]: + """Internal helper to run inference using Gemini model with concurrency.""" + return _execute_inference_concurrently( + api_client=api_client, + model_or_fn=model, + prompt_dataset=prompt_dataset, + progress_desc="Gemini Inference", + gemini_config=config, + inference_fn=_generate_content_with_retry, + ) + + def _run_custom_inference( model_fn: Callable[[Any], Any], prompt_dataset: pd.DataFrame, ) -> list[Any]: """Internal helper to run inference using a custom function with concurrency.""" - logger.info( - "Generating responses for %d prompts using custom function.", - len(prompt_dataset), + return _execute_inference_concurrently( + api_client=None, + model_or_fn=model_fn, + 
prompt_dataset=prompt_dataset, + progress_desc="Custom Inference", ) - responses: list[Union[Any, None]] = [None] * len(prompt_dataset) - tasks = [] - - # Determine the primary column for prompts - if "prompt" in prompt_dataset.columns: - primary_prompt_column = "prompt" - elif "request" in prompt_dataset.columns: - primary_prompt_column = "request" - else: - raise ValueError("Dataset must contain either 'prompt' or 'request'.") - - with tqdm(total=len(prompt_dataset), desc="Custom Inference") as pbar: - with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: - for index, row in prompt_dataset.iterrows(): - contents_input = row[primary_prompt_column] - - # For custom functions, we pass the content as is, assuming the function knows how to handle it. - if isinstance(contents_input, str): - try: - maybe_json = json.loads(contents_input) - # If it's a dict and has 'contents', pass that, else pass the parsed object - if isinstance(maybe_json, dict) and "contents" in maybe_json: - contents_for_fn = maybe_json["contents"] - else: - contents_for_fn = maybe_json - except json.JSONDecodeError: - contents_for_fn = contents_input # Pass as string - else: - contents_for_fn = contents_input - - future = executor.submit(model_fn, contents_for_fn) - future.add_done_callback(lambda _: pbar.update(1)) - tasks.append((future, index)) - - for future, index in tasks: - try: - result = future.result() - responses[index] = result - except Exception as e: - logger.error("Error processing prompt at index %d: %s", index, e) - responses[index] = {"error": f"Custom Inference task failed: {e}"} - return responses - - -def _load_dataframe( - api_client: BaseApiClient, src: Union[str, pd.DataFrame] -) -> pd.DataFrame: - """Loads and prepares the prompt dataset for inference.""" - logger.info("Loading prompt dataset from: %s", src) - try: - loader = _evals_utils.EvalDatasetLoader(api_client=api_client) - dataset_list_of_dicts = loader.load(src) - df = pd.DataFrame(dataset_list_of_dicts) - except Exception as e: - logger.error("Failed to load prompt dataset from source: %s. Error: %s", src, e) - raise e - return df - - -def _apply_prompt_template( - df: pd.DataFrame, prompt_template: types.PromptTemplate -) -> None: - """Applies a prompt template to a DataFrame. - - The DataFrame is expected to have columns corresponding to the variables - in the prompt_template_str. The result will be in a new 'request' column. - - Args: - df: The input DataFrame to modify. - prompt_template: The prompt template to apply. - - Returns: - None. The DataFrame is modified in place. - """ - missing_vars = [var for var in prompt_template.variables if var not in df.columns] - if missing_vars: - raise ValueError( - "Missing columns in DataFrame for prompt template variables:" - f" {', '.join(missing_vars)}. Available columns:" - f" {', '.join(df.columns.tolist())}" - ) - - if "prompt" in df.columns: - logger.info( - "Templated prompts stored in 'request' and will be used for" - " inference.Original 'prompt' column is kept but not used for" - " inference." 
- ) - elif "prompt" not in df.columns and "request" in df.columns: - logger.info("The 'request' column will be replaced with templated prompts.") - - templated_prompts = [] - for _, row in df.iterrows(): - templated_prompts.append(prompt_template.assemble(**row.to_dict())) - - df["request"] = templated_prompts def _run_inference_internal( api_client: BaseApiClient, model: Union[Callable[[Any], Any], str], prompt_dataset: pd.DataFrame, - dest: Optional[str] = None, config: Optional[genai_types.GenerateContentConfig] = None, ) -> pd.DataFrame: """Runs inference on a given dataset using the specified model or function.""" - start_time = time.time() - logger.debug("Starting inference process ...") - - if prompt_dataset.empty: - raise ValueError("Prompt dataset 'prompt_dataset' must not be empty.") if isinstance(model, str): logger.info("Running inference with model name: %s", model) - responses_raw = _run_gemini_inference( + raw_responses = _run_gemini_inference( api_client=api_client, model=model, prompt_dataset=prompt_dataset, config=config, ) processed_responses = [] - for resp_item in responses_raw: + for resp_item in raw_responses: if isinstance(resp_item, genai_types.GenerateContentResponse): text_response = resp_item.text processed_responses.append( @@ -406,13 +309,8 @@ def _run_inference_internal( error_payload = { "error": "Unexpected response type from Gemini inference", "response_type": str(type(resp_item)), + "details": str(resp_item), } - if hasattr(resp_item, "model_dump_json"): - error_payload["details"] = resp_item.model_dump_json() - elif isinstance(resp_item, (dict, list)): - error_payload["details"] = json.dumps(resp_item) - else: - error_payload["details"] = str(resp_item) processed_responses.append(json.dumps(error_payload)) responses = processed_responses @@ -439,20 +337,12 @@ def _run_inference_internal( " name) or Callable." ) - if len(prompt_dataset) != len(responses): - logger.error( - "Critical prompt/response count mismatch: %d prompts vs %d responses." - " This indicates an issue in response collection.", - len(prompt_dataset), - len(responses), + if len(responses) != len(prompt_dataset): + raise RuntimeError( + "Critical prompt/response count mismatch: %d prompts vs %d" + " responses. This indicates an issue in response collection." + % (len(prompt_dataset), len(responses)) ) - if len(responses) < len(prompt_dataset): - responses.extend( - [json.dumps({"error": "Missing response"})] - * (len(prompt_dataset) - len(responses)) - ) - else: - responses = responses[: len(prompt_dataset)] results_df_responses_only = pd.DataFrame( { @@ -467,40 +357,62 @@ def _run_inference_internal( [prompt_dataset_indexed, results_df_responses_only_indexed], axis=1 ) - if dest: - file_name = "inference_results.jsonl" - full_dest_path = dest - is_gcs_path = dest.startswith(_evals_utils.GCS_PREFIX) + return results_df - if is_gcs_path: - if not dest.endswith("/"): - pass - else: - full_dest_path = os.path.join(dest, file_name) - else: - if os.path.isdir(dest): - full_dest_path = os.path.join(dest, file_name) - os.makedirs(os.path.dirname(full_dest_path), exist_ok=True) +def _apply_prompt_template( + df: pd.DataFrame, prompt_template: types.PromptTemplate +) -> None: + """Applies a prompt template to a DataFrame. 
- logger.info("Saving inference results to: %s", full_dest_path) - try: - if is_gcs_path: - _evals_utils.GcsUtils(api_client=api_client).upload_dataframe( - df=results_df, - gcs_destination_blob_path=full_dest_path, - file_type="jsonl", - ) - logger.info("Results saved to GCS: %s", full_dest_path) - else: - results_df.to_json(full_dest_path, orient="records", lines=True) - logger.info("Results saved locally to: %s", full_dest_path) - except Exception as e: # pylint: disable=broad-exception-caught - logger.error("Failed to save results to %s. Error: %s", full_dest_path, e) + The DataFrame is expected to have columns corresponding to the variables + in the prompt_template_str. The result will be in a new 'request' column. - end_time = time.time() - logger.info("Inference completed in %.2f seconds.", end_time - start_time) - return results_df + Args: + df: The input DataFrame to modify. + prompt_template: The prompt template to apply. + + Returns: + None. The DataFrame is modified in place. + """ + missing_vars = [var for var in prompt_template.variables if var not in df.columns] + if missing_vars: + raise ValueError( + "Missing columns in DataFrame for prompt template variables:" + f" {', '.join(missing_vars)}. Available columns:" + f" {', '.join(df.columns.tolist())}" + ) + + if "prompt" in df.columns: + logger.info( + "Templated prompts stored in 'request' and will be used for" + " inference.Original 'prompt' column is kept but not used for" + " inference." + ) + elif "prompt" not in df.columns and "request" in df.columns: + logger.info("The 'request' column will be replaced with templated prompts.") + + templated_prompts = [] + for _, row in df.iterrows(): + templated_prompts.append(prompt_template.assemble(**row.to_dict())) + + df["request"] = templated_prompts + + +def _load_dataframe( + api_client: BaseApiClient, src: Union[str, pd.DataFrame] +) -> pd.DataFrame: + """Loads and prepares the prompt dataset for inference.""" + logger.info("Loading prompt dataset from: %s", src) + try: + loader = _evals_utils.EvalDatasetLoader(api_client=api_client) + dataset_list_of_dicts = loader.load(src) + if not dataset_list_of_dicts: + raise ValueError("Prompt dataset 'prompt_dataset' must not be empty.") + return pd.DataFrame(dataset_list_of_dicts) + except Exception as e: + logger.error("Failed to load prompt dataset from source: %s. Error: %s", src, e) + raise e def _execute_inference( @@ -512,6 +424,23 @@ def _execute_inference( config: Optional[genai_types.GenerateContentConfig] = None, prompt_template: Optional[Union[str, types.PromptTemplateOrDict]] = None, ) -> pd.DataFrame: + """Executes inference on a given dataset using the specified model. + + Args: + api_client: The API client. + model: The model to use for inference. Can be a callable function or a + string representing a model. + src: The source of the dataset to use for inference. Can be a string + representing a file path or a pandas DataFrame. + dest: The destination to save the inference results. Can be a string + representing a file path or a GCS URI. + config: The generation configuration for the model. + prompt_template: The prompt template to use for inference. + + Returns: + A pandas DataFrame containing the inference results. 
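+
+    Raises:
+        ValueError: If api_client is not provided, if the loaded dataset is
+            empty, or if it does not contain a 'prompt' or 'request' column.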
+ """ + if not api_client: raise ValueError("'api_client' instance must be provided.") prompt_dataset = _load_dataframe(api_client, src) @@ -535,13 +464,47 @@ def _execute_inference( f"Found columns: {prompt_dataset.columns.tolist()}" ) + start_time = time.time() + logger.debug("Starting inference process ...") results_df = _run_inference_internal( api_client=api_client, model=model, prompt_dataset=prompt_dataset, - dest=dest, config=config, ) + end_time = time.time() + logger.info("Inference completed in %.2f seconds.", end_time - start_time) + + if dest: + file_name = "inference_results.jsonl" + full_dest_path = dest + is_gcs_path = dest.startswith(_evals_utils.GCS_PREFIX) + + if is_gcs_path: + if not dest.endswith("/"): + pass + else: + full_dest_path = os.path.join(dest, file_name) + else: + if os.path.isdir(dest): + full_dest_path = os.path.join(dest, file_name) + + os.makedirs(os.path.dirname(full_dest_path), exist_ok=True) + + logger.info("Saving inference results to: %s", full_dest_path) + try: + if is_gcs_path: + _evals_utils.GcsUtils(api_client=api_client).upload_dataframe( + df=results_df, + gcs_destination_blob_path=full_dest_path, + file_type="jsonl", + ) + logger.info("Results saved to GCS: %s", full_dest_path) + else: + results_df.to_json(full_dest_path, orient="records", lines=True) + logger.info("Results saved locally to: %s", full_dest_path) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Failed to save results to %s. Error: %s", full_dest_path, e) return types.EvaluationDataset(eval_dataset_df=results_df) From 701b8d40ba7b1265051a8b6a507e8a6b8e242a54 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Thu, 12 Jun 2025 19:08:50 -0700 Subject: [PATCH 06/24] feat: Adding VAPO Prompt Optimizer (PO-data) to the genai SDK. PiperOrigin-RevId: 770884761 --- .../vertexai/genai/test_prompt_optimizer.py | 104 +++++++ vertexai/_genai/client.py | 16 +- vertexai/_genai/evals.py | 10 +- vertexai/_genai/prompt_optimizer.py | 259 ++++++++++++++++++ vertexai/_genai/types.py | 81 +++++- 5 files changed, 463 insertions(+), 7 deletions(-) create mode 100644 tests/unit/vertexai/genai/test_prompt_optimizer.py create mode 100644 vertexai/_genai/prompt_optimizer.py diff --git a/tests/unit/vertexai/genai/test_prompt_optimizer.py b/tests/unit/vertexai/genai/test_prompt_optimizer.py new file mode 100644 index 0000000000..73f82736d4 --- /dev/null +++ b/tests/unit/vertexai/genai/test_prompt_optimizer.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pylint: disable=protected-access,bad-continuation +import copy +import importlib +from unittest import mock + +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform import initializer as aiplatform_initializer +from google.cloud.aiplatform.compat.services import job_service_client +from google.cloud.aiplatform.compat.types import ( + custom_job as gca_custom_job_compat, +) +from google.cloud.aiplatform.compat.types import io as gca_io_compat +from google.cloud.aiplatform.compat.types import ( + job_state as gca_job_state_compat, +) +from google.cloud.aiplatform.utils import gcs_utils +from google.genai import client +import pytest + + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +pytestmark = pytest.mark.usefixtures("google_auth_mock") +_TEST_PROJECT_NUMBER = "12345678" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345" +_TEST_BASE_OUTPUT_DIR = "gs://test_bucket/test_base_output_dir" + +_TEST_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob( + display_name=_TEST_DISPLAY_NAME, + job_spec={ + "base_output_directory": gca_io_compat.GcsDestination( + output_uri_prefix=_TEST_BASE_OUTPUT_DIR + ), + }, + labels={"trained_by_vertex_ai": "true"}, +) + + +@pytest.fixture +def mock_create_custom_job(): + with mock.patch.object( + job_service_client.JobServiceClient, "create_custom_job" + ) as create_custom_job_mock: + custom_job_proto = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO) + custom_job_proto.name = _TEST_DISPLAY_NAME + custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_PENDING + create_custom_job_mock.return_value = custom_job_proto + yield create_custom_job_mock + + +class TestPromptOptimizer: + """Unit tests for the Prompt Optimizer client.""" + + def setup_method(self): + importlib.reload(aiplatform_initializer) + importlib.reload(aiplatform) + importlib.reload(vertexai) + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + @pytest.mark.usefixtures("google_auth_mock") + def test_prompt_optimizer_client(self): + test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + assert test_client is not None + assert test_client._api_client.vertexai + assert test_client._api_client.project == _TEST_PROJECT + assert test_client._api_client.location == _TEST_LOCATION + + @mock.patch.object(client.Client, "_get_api_client") + @mock.patch.object( + gcs_utils.resource_manager_utils, "get_project_number", return_value=12345 + ) + def test_prompt_optimizer_optimize( + self, mock_get_project_number, mock_client, mock_create_custom_job + ): + """Test that prompt_optimizer.optimize method creates a custom job.""" + test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_client.prompt_optimizer.optimize( + method="vapo", + config={ + "config_path": "gs://ssusie-vapo-sdk-test/config.json", + "wait_for_completion": False, + }, + ) + mock_create_custom_job.assert_called_once() + mock_get_project_number.assert_called_once() diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py index 3102c653c8..6371eb727e 100644 --- a/vertexai/_genai/client.py +++ b/vertexai/_genai/client.py @@ -14,7 +14,6 @@ # import importlib - from typing import Optional, Union import google.auth @@ -24,6 +23,7 @@ class AsyncClient: + """Async Client for the GenAI SDK.""" def __init__(self, api_client: client.Client): @@ -50,6 +50,8 @@ def evals(self): ) from e return self._evals.AsyncEvals(self._api_client) + # 
TODO(b/424176979): add async prompt optimizer here. + class Client: """Client for the GenAI SDK. @@ -101,6 +103,7 @@ def __init__( http_options=http_options, ) self._evals = None + self._prompt_optimizer = None @property @_common.experimental_warning( @@ -120,3 +123,14 @@ def evals(self): "google-cloud-aiplatform[evaluation]" ) from e return self._evals.Evals(self._api_client) + + @property + @_common.experimental_warning( + "The Vertex SDK GenAI prompt optimizer module is experimental, " + "and may change in future versions." + ) + def prompt_optimizer(self): + self._prompt_optimizer = importlib.import_module( + ".prompt_optimizer", __package__ + ) + return self._prompt_optimizer.PromptOptimizer(self._api_client) diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py index 9bd2637313..093b6e6def 100644 --- a/vertexai/_genai/evals.py +++ b/vertexai/_genai/evals.py @@ -18,6 +18,7 @@ import logging from typing import Any, Callable, Optional, Union from urllib.parse import urlencode + from google.genai import _api_module from google.genai import _common from google.genai import types as genai_types @@ -25,6 +26,7 @@ from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv import pandas as pd + from . import _evals_common from . import types @@ -1238,9 +1240,11 @@ def evaluate( config = types.EvaluateMethodConfig.model_validate(config) if isinstance(dataset, list): dataset = [ - types.EvaluationDataset.model_validate(ds_item) - if isinstance(ds_item, dict) - else ds_item + ( + types.EvaluationDataset.model_validate(ds_item) + if isinstance(ds_item, dict) + else ds_item + ) for ds_item in dataset ] else: diff --git a/vertexai/_genai/prompt_optimizer.py b/vertexai/_genai/prompt_optimizer.py new file mode 100644 index 0000000000..6065260d9d --- /dev/null +++ b/vertexai/_genai/prompt_optimizer.py @@ -0,0 +1,259 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import datetime +import logging +from typing import Any, Optional, Union +from urllib.parse import urlencode + +from google.cloud import aiplatform +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._api_client import BaseApiClient +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv + +from . 
import types + + +logger = logging.getLogger("vertexai_genai.promptoptimizer") + + +def _OptimizeRequestParameters_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _OptimizeResponse_from_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +class PromptOptimizer(_api_module.BaseModule): + """Prompt Optimizer""" + + def optimize_dummy( + self, *, config: Optional[types.OptimizeConfigOrDict] = None + ) -> types.OptimizeResponse: + """Optimiza a multiple prompts.""" + + parameter_model = types._OptimizeRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _OptimizeRequestParameters_to_vertex( + self._api_client, parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":optimize".format_map(request_url_dict) + else: + path = ":optimize" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + "post", path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _OptimizeResponse_from_vertex( + self._api_client, response_dict + ) + + return_value = types.OptimizeResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + return return_value + + """Prompt Optimizer PO-Data.""" + + def _create_custom_job( + self, + display_name: str, + container_uri: str, + bucket: str, + container_args: dict[str, str], + service_account: str, + ) -> aiplatform.CustomJob: + """Create a custom jobs.""" + args = ["--%s=%s" % (k, v) for k, v in container_args.items()] + worker_pool_specs = [ + { + "replica_count": 1, + "container_spec": { + "image_uri": container_uri, + "args": args, + }, + "machine_spec": { + "machine_type": "n1-standard-4", + }, + } + ] + + custom_job = aiplatform.CustomJob( + display_name=display_name, + worker_pool_specs=worker_pool_specs, + staging_bucket=bucket, + ) + custom_job.submit(service_account=service_account) + return custom_job + + def optimize( + self, + method: str, + config: types.PromptOptimizerVAPOConfig, + ) -> aiplatform.CustomJob: + """Call PO-Data optimizer. + + Args: + method: The method for optimizing multiple prompts. + config: The config to use. Config consists of the following fields: - + config_path: The gcs path to the config file, e.g. + gs://bucket/config.json. - wait_for_completion: Optional. Whether to + wait for the job to complete. Default is True. 
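+
+        Returns:
+            The submitted aiplatform.CustomJob running the optimization.
+
+        Example (illustrative; the project, location, and GCS config path
+        below are placeholders for your own values):
+
+            client = vertexai.Client(project="my-project", location="us-central1")
+            job = client.prompt_optimizer.optimize(
+                method="vapo",
+                config={
+                    "config_path": "gs://my-bucket/vapo_config.json",
+                    "wait_for_completion": True,
+                },
+            )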
+ """ + + if method != "vapo": + raise ValueError("Only vapo methods is currently supported.") + + timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + display_name = f"vapo-optimizer-{timestamp}" + wait_for_completion = config["wait_for_completion"] + bucket = "/".join(config["config_path"].split("/")[:-1]) + + container_uri = "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0" + + region = self._api_client.location + project = self._api_client.project + project_number = aiplatform.utils.resource_manager_utils.get_project_number( + project + ) + service_account = f"{project_number}-compute@developer.gserviceaccount.com" + + job = self._create_custom_job( + display_name, + container_uri, + bucket, + { + "config": config["config_path"], + }, + service_account, + ) + + # Get the job resource name + job_resource_name = job.resource_name + job_id = job_resource_name.split("/")[-1] + logger.info("Job created: %s", job.resource_name) + + # Construct the dashboard URL + dashboard_url = f"https://pantheon.corp.google.com/vertex-ai/locations/{region}/training/{job_id}/cpu?e=13802955&project={project}" + logger.info("View the job status at: %s", dashboard_url) + + if wait_for_completion: + logger.info("Waiting for the job to finish: %s", job.display_name) + job.wait_for_completion() + return job + + +class AsyncPromptOptimizer(_api_module.BaseModule): + """Prompt Optimizer""" + + async def optimize_dummy( + self, *, config: Optional[types.OptimizeConfigOrDict] = None + ) -> types.OptimizeResponse: + """Optimiza a multiple prompts.""" + + parameter_model = types._OptimizeRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _OptimizeRequestParameters_to_vertex( + self._api_client, parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":optimize".format_map(request_url_dict) + else: + path = ":optimize" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
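+        # The config is client-side only (it currently carries just
+        # http_options), so it is not sent as part of the request payload.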
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _OptimizeResponse_from_vertex( + self._api_client, response_dict + ) + + return_value = types.OptimizeResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + return return_value diff --git a/vertexai/_genai/types.py b/vertexai/_genai/types.py index 506267c1dc..df5105053d 100644 --- a/vertexai/_genai/types.py +++ b/vertexai/_genai/types.py @@ -2176,9 +2176,9 @@ def to_yaml_file(self, file_path: str, version: Optional[str] = None) -> None: exclude_unset=True, exclude_none=True, mode="json", - exclude=fields_to_exclude_callables - if fields_to_exclude_callables - else None, + exclude=( + fields_to_exclude_callables if fields_to_exclude_callables else None + ), ) if version: @@ -2382,6 +2382,81 @@ class EvaluateDatasetOperationDict(TypedDict, total=False): ] +class OptimizeConfig(_common.BaseModel): + """Config for Prompt Optimizer.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class OptimizeConfigDict(TypedDict, total=False): + """Config for Prompt Optimizer.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + +OptimizeConfigOrDict = Union[OptimizeConfig, OptimizeConfigDict] + + +class _OptimizeRequestParameters(_common.BaseModel): + """Parameters for the optimize_prompt method.""" + + config: Optional[OptimizeConfig] = Field(default=None, description="""""") + + +class _OptimizeRequestParametersDict(TypedDict, total=False): + """Parameters for the optimize_prompt method.""" + + config: Optional[OptimizeConfigDict] + """""" + + +_OptimizeRequestParametersOrDict = Union[ + _OptimizeRequestParameters, _OptimizeRequestParametersDict +] + + +class OptimizeResponse(_common.BaseModel): + """Response for the optimize_prompt method.""" + + pass + + +class OptimizeResponseDict(TypedDict, total=False): + """Response for the optimize_prompt method.""" + + pass + + +OptimizeResponseOrDict = Union[OptimizeResponse, OptimizeResponseDict] + + +class PromptOptimizerVAPOConfig(_common.BaseModel): + """VAPO Prompt Optimizer Config.""" + + config_path: Optional[str] = Field( + default=None, description="""The gcs path to the config file.""" + ) + wait_for_completion: Optional[bool] = Field(default=None, description="""""") + + +class PromptOptimizerVAPOConfigDict(TypedDict, total=False): + """VAPO Prompt Optimizer Config.""" + + config_path: Optional[str] + """The gcs path to the config file.""" + + wait_for_completion: Optional[bool] + """""" + + +PromptOptimizerVAPOConfigOrDict = Union[ + PromptOptimizerVAPOConfig, PromptOptimizerVAPOConfigDict +] + + class PromptTemplate(_common.BaseModel): """A prompt template for creating prompts with variables.""" From d69ef6b524014f54b79129e1417f985293fd5483 Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Fri, 13 Jun 2025 12:56:24 -0700 Subject: [PATCH 07/24] chore: internal change PiperOrigin-RevId: 771194451 --- 
testing/constraints-3.12.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/testing/constraints-3.12.txt b/testing/constraints-3.12.txt index 64a02b12c6..7403aeafd3 100644 --- a/testing/constraints-3.12.txt +++ b/testing/constraints-3.12.txt @@ -11,4 +11,5 @@ pytest-xdist==3.3.1 # Pinned to unbreak unit tests ray==2.5.0 # Pinned until 2.9.3 is verified for Ray tests ipython==8.22.2 # Pinned to unbreak TypeAliasType import error google-adk==0.0.2 -google-genai>=1.10.0 \ No newline at end of file +google-genai>=1.10.0 +google-vizier==0.1.21 \ No newline at end of file From 9b48d24ab90c57d4a49b3adf22a79cffbe065351 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Fri, 13 Jun 2025 15:11:39 -0700 Subject: [PATCH 08/24] Copybara import of the project: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -- 73bda4c66d5bef69fa946c915a31f0dd02e5a271 by Owl Bot : feat: Introduce RagFileMetadataConfig for importing metadata to Rag PiperOrigin-RevId: 770274285 Source-Link: https://github.com/googleapis/googleapis/commit/4cdc2aa1af11046c38ee1a4b1f6f4b20f0d49e2b Source-Link: https://github.com/googleapis/googleapis-gen/commit/9dbe3e0f2dc959cc26c594c74a9a93e408c158b6 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiOWRiZTNlMGYyZGM5NTljYzI2YzU5NGM3NGE5YTkzZTQwOGMxNThiNiJ9 -- 08e37c3f675a54a2d435cad2032d526f1f96cd4a by Owl Bot : feat: add EncryptionSpec field for RagCorpus CMEK feature to v1 PiperOrigin-RevId: 770837205 Source-Link: https://github.com/googleapis/googleapis/commit/3a45aa38968f4fe537cd3837c0e95af8c189b11b Source-Link: https://github.com/googleapis/googleapis-gen/commit/6d6b54fc3e11bd79c520a4df8cb1561a2c6aa149 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiNmQ2YjU0ZmMzZTExYmQ3OWM1MjBhNGRmOGNiMTU2MWEyYzZhYTE0OSJ9 -- 2e54a6891784bf4f714ef96511f72c97f47d33f6 by Owl Bot : 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/5422 from googleapis:owl-bot-copy 8660dc9fad939027807d1ed7e189e4ca87fe2ad2 PiperOrigin-RevId: 771240575 --- .../services/migration_service/client.py | 18 +-- .../vertex_rag_data_service/async_client.py | 1 + .../vertex_rag_data_service/client.py | 1 + .../cloud/aiplatform_v1/types/tuning_job.py | 8 +- .../aiplatform_v1/types/vertex_rag_data.py | 12 ++ google/cloud/aiplatform_v1beta1/__init__.py | 2 + .../aiplatform_v1beta1/types/__init__.py | 2 + .../aiplatform_v1beta1/types/tuning_job.py | 9 +- .../types/vertex_rag_data.py | 133 ++++++++++++++++++ ...t_metadata_google.cloud.aiplatform.v1.json | 2 +- ...adata_google.cloud.aiplatform.v1beta1.json | 2 +- .../aiplatform_v1/test_migration_service.py | 26 ++-- .../test_vertex_rag_data_service.py | 5 + .../test_vertex_rag_data_service.py | 9 ++ 14 files changed, 202 insertions(+), 28 deletions(-) diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index aba8658b32..735b4d1655 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -265,40 +265,40 @@ def parse_dataset_path(path: str) -> Dict[str, str]: @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return 
"projects/{project}/locations/{location}/datasets/{dataset}".format( + return "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, + location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( + return "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py index 51cdffea74..e0213cf348 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py @@ -48,6 +48,7 @@ from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.vertex_rag_data_service import pagers +from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import operation as gca_operation from google.cloud.aiplatform_v1.types import vertex_rag_data diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py index a673f1f189..ee407688da 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py @@ -64,6 +64,7 @@ from google.api_core import operation as gac_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1.services.vertex_rag_data_service import pagers +from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import operation as gca_operation from google.cloud.aiplatform_v1.types import vertex_rag_data diff --git a/google/cloud/aiplatform_v1/types/tuning_job.py b/google/cloud/aiplatform_v1/types/tuning_job.py index 43173fc6c4..0ae93dc179 100644 --- a/google/cloud/aiplatform_v1/types/tuning_job.py +++ b/google/cloud/aiplatform_v1/types/tuning_job.py @@ -529,9 +529,13 @@ class SupervisedTuningSpec(proto.Message): Attributes: training_dataset_uri (str): - Required. Training dataset used for tuning. The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset. + Required. Cloud Storage path to file + containing training dataset for tuning. The + dataset must be formatted as a JSONL file. validation_dataset_uri (str): - Optional. Validation dataset used for tuning. 
The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset. + Optional. Cloud Storage path to file + containing validation dataset for tuning. The + dataset must be formatted as a JSONL file. hyper_parameters (google.cloud.aiplatform_v1.types.SupervisedHyperParameters): Optional. Hyperparameters for SFT. export_last_checkpoint_only (bool): diff --git a/google/cloud/aiplatform_v1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1/types/vertex_rag_data.py index 33fdc7962d..fcab74428f 100644 --- a/google/cloud/aiplatform_v1/types/vertex_rag_data.py +++ b/google/cloud/aiplatform_v1/types/vertex_rag_data.py @@ -20,6 +20,7 @@ import proto # type: ignore from google.cloud.aiplatform_v1.types import api_auth as gca_api_auth +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1.types import io from google.protobuf import timestamp_pb2 # type: ignore @@ -414,6 +415,12 @@ class RagCorpus(proto.Message): was last updated. corpus_status (google.cloud.aiplatform_v1.types.CorpusStatus): Output only. RagCorpus state. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Optional. Immutable. The CMEK key name used + to encrypt at-rest data related to this Corpus. + Only applicable to RagManagedDb option for + Vector DB. This field can only be set at corpus + creation time, and cannot be updated or deleted. """ vector_db_config: "RagVectorDbConfig" = proto.Field( @@ -455,6 +462,11 @@ class RagCorpus(proto.Message): number=8, message="CorpusStatus", ) + encryption_spec: gca_encryption_spec.EncryptionSpec = proto.Field( + proto.MESSAGE, + number=12, + message=gca_encryption_spec.EncryptionSpec, + ) class RagFile(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 3b64f7f3ba..2782cd4534 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -1202,6 +1202,7 @@ from .types.vertex_rag_data import RagEngineConfig from .types.vertex_rag_data import RagFile from .types.vertex_rag_data import RagFileChunkingConfig +from .types.vertex_rag_data import RagFileMetadataConfig from .types.vertex_rag_data import RagFileParsingConfig from .types.vertex_rag_data import RagFileTransformationConfig from .types.vertex_rag_data import RagManagedDbConfig @@ -2128,6 +2129,7 @@ "RagEngineConfig", "RagFile", "RagFileChunkingConfig", + "RagFileMetadataConfig", "RagFileParsingConfig", "RagFileTransformationConfig", "RagManagedDbConfig", diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index bcf0f9267c..664e401502 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -1355,6 +1355,7 @@ RagEngineConfig, RagFile, RagFileChunkingConfig, + RagFileMetadataConfig, RagFileParsingConfig, RagFileTransformationConfig, RagManagedDbConfig, @@ -2499,6 +2500,7 @@ "RagEngineConfig", "RagFile", "RagFileChunkingConfig", + "RagFileMetadataConfig", "RagFileParsingConfig", "RagFileTransformationConfig", "RagManagedDbConfig", diff --git a/google/cloud/aiplatform_v1beta1/types/tuning_job.py b/google/cloud/aiplatform_v1beta1/types/tuning_job.py index d26488ad46..f6661202bb 100644 --- a/google/cloud/aiplatform_v1beta1/types/tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/tuning_job.py @@ -764,9 +764,14 @@ class 
SupervisedTuningSpec(proto.Message): Attributes: training_dataset_uri (str): - Required. Training dataset used for tuning. The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset. + Required. Cloud Storage path to file + containing training dataset for tuning. The + dataset must be formatted as a JSONL file. validation_dataset_uri (str): - Optional. Validation dataset used for tuning. The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset.a1.types.SupervisedHyperParameters): + Optional. Cloud Storage path to file + containing validation dataset for tuning. The + dataset must be formatted as a JSONL file. + hyper_parameters (google.cloud.aiplatform_v1beta1.types.SupervisedHyperParameters): Optional. Hyperparameters for SFT. export_last_checkpoint_only (bool): Optional. If set to true, disable diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py index 183515f8dd..557ce4d6bd 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py @@ -39,6 +39,7 @@ "RagFileChunkingConfig", "RagFileTransformationConfig", "RagFileParsingConfig", + "RagFileMetadataConfig", "UploadRagFileConfig", "ImportRagFilesConfig", "RagManagedDbConfig", @@ -776,6 +777,9 @@ class RagFile(proto.Message): last updated. file_status (google.cloud.aiplatform_v1beta1.types.FileStatus): Output only. State of the RagFile. + user_metadata (str): + Output only. The metadata for metadata + search. The contents will be be in JSON format. """ class RagFileType(proto.Enum): @@ -865,6 +869,10 @@ class RagFileType(proto.Enum): number=13, message="FileStatus", ) + user_metadata: str = proto.Field( + proto.STRING, + number=15, + ) class RagChunk(proto.Message): @@ -1136,6 +1144,103 @@ class LlmParser(proto.Message): ) +class RagFileMetadataConfig(proto.Message): + r"""Metadata config for RagFile. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + gcs_metadata_schema_source (google.cloud.aiplatform_v1beta1.types.GcsSource): + Google Cloud Storage location. Supports importing individual + files as well as entire Google Cloud Storage directories. + Sample formats: + + - ``gs://bucket_name/my_directory/object_name/metadata_schema.json`` + - ``gs://bucket_name/my_directory`` If providing a + directory, the metadata schema will be read from the + files that ends with "metadata_schema.json" in the + directory. + + This field is a member of `oneof`_ ``metadata_schema_source``. + google_drive_metadata_schema_source (google.cloud.aiplatform_v1beta1.types.GoogleDriveSource): + Google Drive location. Supports importing individual files + as well as Google Drive folders. If providing a folder, the + metadata schema will be read from the files that ends with + "metadata_schema.json" in the directory. + + This field is a member of `oneof`_ ``metadata_schema_source``. + inline_metadata_schema_source (str): + Inline metadata schema source. Must be a JSON + string. + + This field is a member of `oneof`_ ``metadata_schema_source``. 
+ gcs_metadata_source (google.cloud.aiplatform_v1beta1.types.GcsSource): + Google Cloud Storage location. Supports importing individual + files as well as entire Google Cloud Storage directories. + Sample formats: + + - ``gs://bucket_name/my_directory/object_name/metadata.json`` + - ``gs://bucket_name/my_directory`` If providing a + directory, the metadata will be read from the files that + ends with "metadata.json" in the directory. + + This field is a member of `oneof`_ ``metadata_source``. + google_drive_metadata_source (google.cloud.aiplatform_v1beta1.types.GoogleDriveSource): + Google Drive location. Supports importing + individual files as well as Google Drive + folders. If providing a directory, the metadata + will be read from the files that ends with + "metadata.json" in the directory. + + This field is a member of `oneof`_ ``metadata_source``. + inline_metadata_source (str): + Inline metadata source. Must be a JSON + string. + + This field is a member of `oneof`_ ``metadata_source``. + """ + + gcs_metadata_schema_source: io.GcsSource = proto.Field( + proto.MESSAGE, + number=1, + oneof="metadata_schema_source", + message=io.GcsSource, + ) + google_drive_metadata_schema_source: io.GoogleDriveSource = proto.Field( + proto.MESSAGE, + number=2, + oneof="metadata_schema_source", + message=io.GoogleDriveSource, + ) + inline_metadata_schema_source: str = proto.Field( + proto.STRING, + number=3, + oneof="metadata_schema_source", + ) + gcs_metadata_source: io.GcsSource = proto.Field( + proto.MESSAGE, + number=4, + oneof="metadata_source", + message=io.GcsSource, + ) + google_drive_metadata_source: io.GoogleDriveSource = proto.Field( + proto.MESSAGE, + number=5, + oneof="metadata_source", + message=io.GoogleDriveSource, + ) + inline_metadata_source: str = proto.Field( + proto.STRING, + number=6, + oneof="metadata_source", + ) + + class UploadRagFileConfig(proto.Message): r"""Config for uploading RagFile. @@ -1146,6 +1251,15 @@ class UploadRagFileConfig(proto.Message): rag_file_transformation_config (google.cloud.aiplatform_v1beta1.types.RagFileTransformationConfig): Specifies the transformation config for RagFiles. + rag_file_metadata_config (google.cloud.aiplatform_v1beta1.types.RagFileMetadataConfig): + Specifies the metadata config for RagFiles. + Including paths for metadata schema and + metadata. Alteratively, inline metadata schema + and metadata can be provided. + rag_file_parsing_config (google.cloud.aiplatform_v1beta1.types.RagFileParsingConfig): + Optional. Specifies the parsing config for + RagFiles. RAG will use the default parser if + this field is not set. """ rag_file_chunking_config: "RagFileChunkingConfig" = proto.Field( @@ -1158,6 +1272,16 @@ class UploadRagFileConfig(proto.Message): number=3, message="RagFileTransformationConfig", ) + rag_file_metadata_config: "RagFileMetadataConfig" = proto.Field( + proto.MESSAGE, + number=4, + message="RagFileMetadataConfig", + ) + rag_file_parsing_config: "RagFileParsingConfig" = proto.Field( + proto.MESSAGE, + number=5, + message="RagFileParsingConfig", + ) class ImportRagFilesConfig(proto.Message): @@ -1241,6 +1365,10 @@ class ImportRagFilesConfig(proto.Message): Optional. Specifies the parsing config for RagFiles. RAG will use the default parser if this field is not set. + rag_file_metadata_config (google.cloud.aiplatform_v1beta1.types.RagFileMetadataConfig): + Specifies the metadata config for RagFiles. + Including paths for metadata schema and + metadata. max_embedding_requests_per_min (int): Optional. 
The max number of queries per minute that this job is allowed to make to the @@ -1338,6 +1466,11 @@ class ImportRagFilesConfig(proto.Message): number=8, message="RagFileParsingConfig", ) + rag_file_metadata_config: "RagFileMetadataConfig" = proto.Field( + proto.MESSAGE, + number=17, + message="RagFileMetadataConfig", + ) max_embedding_requests_per_min: int = proto.Field( proto.INT32, number=5, diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index 0e535e9098..cbfa30ab93 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.97.0" + "version": "0.1.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index cafa126840..418c8046bc 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.97.0" + "version": "0.1.0" }, "snippets": [ { diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index cc5515c8d4..9e8dced432 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -5424,22 +5424,19 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - location = "clam" - dataset = "whelk" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + dataset = "clam" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", + "project": "whelk", + "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -5449,19 +5446,22 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "cuttlefish" - dataset = "mussel" - expected = "projects/{project}/datasets/{dataset}".format( + project = "oyster" + location = "nudibranch" + dataset = "cuttlefish" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", + "project": "mussel", + "location": "winkle", "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py index 24a7a2197c..f9e920487e 100644 --- a/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py @@ -75,6 +75,7 @@ from 
google.cloud.aiplatform_v1.services.vertex_rag_data_service import pagers from google.cloud.aiplatform_v1.services.vertex_rag_data_service import transports from google.cloud.aiplatform_v1.types import api_auth +from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import io from google.cloud.aiplatform_v1.types import operation as gca_operation from google.cloud.aiplatform_v1.types import vertex_rag_data @@ -7807,6 +7808,7 @@ def test_create_rag_corpus_rest_call_success(request_type): "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, "corpus_status": {"state": 1, "error_status": "error_status_value"}, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -8036,6 +8038,7 @@ def test_update_rag_corpus_rest_call_success(request_type): "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, "corpus_status": {"state": 1, "error_status": "error_status_value"}, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -10176,6 +10179,7 @@ async def test_create_rag_corpus_rest_asyncio_call_success(request_type): "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, "corpus_status": {"state": 1, "error_status": "error_status_value"}, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -10421,6 +10425,7 @@ async def test_update_rag_corpus_rest_asyncio_call_success(request_type): "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, "corpus_status": {"state": 1, "error_status": "error_status_value"}, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. 
# Delete any fields which are not present in the current runtime dependency diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py index 140f11d9bc..91798a544d 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py @@ -3922,6 +3922,7 @@ def test_get_rag_file(request_type, transport: str = "grpc"): description="description_value", size_bytes=1089, rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, + user_metadata="user_metadata_value", ) response = client.get_rag_file(request) @@ -3940,6 +3941,7 @@ def test_get_rag_file(request_type, transport: str = "grpc"): assert ( response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT ) + assert response.user_metadata == "user_metadata_value" def test_get_rag_file_non_empty_request_with_auto_populated_field(): @@ -4071,6 +4073,7 @@ async def test_get_rag_file_async( description="description_value", size_bytes=1089, rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, + user_metadata="user_metadata_value", ) ) response = await client.get_rag_file(request) @@ -4090,6 +4093,7 @@ async def test_get_rag_file_async( assert ( response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT ) + assert response.user_metadata == "user_metadata_value" @pytest.mark.asyncio @@ -8771,6 +8775,7 @@ async def test_get_rag_file_empty_call_grpc_asyncio(): description="description_value", size_bytes=1089, rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, + user_metadata="user_metadata_value", ) ) await client.get_rag_file(request=None) @@ -10107,6 +10112,7 @@ def test_get_rag_file_rest_call_success(request_type): description="description_value", size_bytes=1089, rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, + user_metadata="user_metadata_value", ) # Wrap the value into a proper Response obj @@ -10130,6 +10136,7 @@ def test_get_rag_file_rest_call_success(request_type): assert ( response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT ) + assert response.user_metadata == "user_metadata_value" @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -13031,6 +13038,7 @@ async def test_get_rag_file_rest_asyncio_call_success(request_type): description="description_value", size_bytes=1089, rag_file_type=vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT, + user_metadata="user_metadata_value", ) # Wrap the value into a proper Response obj @@ -13056,6 +13064,7 @@ async def test_get_rag_file_rest_asyncio_call_success(request_type): assert ( response.rag_file_type == vertex_rag_data.RagFile.RagFileType.RAG_FILE_TYPE_TXT ) + assert response.user_metadata == "user_metadata_value" @pytest.mark.asyncio From f1e17a6b35fb31b7a5eb589a132df5df0ad7e3e4 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Mon, 16 Jun 2025 07:25:27 -0700 Subject: [PATCH 09/24] docs: add GenAI client examples to readme PiperOrigin-RevId: 772016883 --- README.rst | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/README.rst b/README.rst index 4d872cdc61..425931cf0a 100644 --- a/README.rst +++ b/README.rst @@ -10,6 +10,65 @@ Gemini API and Generative AI on Vertex AI For Gemini API and Generative AI on Vertex AI, please reference `Vertex Generative AI SDK for Python`_ .. 
_Vertex Generative AI SDK for Python: https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest +Using the Google Gen AI SDK client from the Vertex AI SDK (Experimental) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To use features from the Google Gen AI SDK from the Vertex AI SDK, you can instantiate the client with the following: + +.. code-block:: Python + + import vertexai + from vertexai import types + + # Instantiate GenAI client from Vertex SDK + # Replace with your project ID and location + client = vertexai.Client(project='my-project', location='us-central1') + +See the examples below for guidance on how to use specific features supported by the Gen AI SDK client. + +Gen AI Evaluation +^^^^^^^^^^^^^^^^^ + +To run evaluation, first generate model responses from a set of prompts. + +.. code-block:: Python + + import pandas as pd + + prompts_df = pd.DataFrame({ + "prompt": [ + "What is the capital of France?", + "Write a haiku about a cat.", + "Write a Python function to calculate the factorial of a number.", + "Translate 'How are you?' to French.", + ], + + "reference": [ + "Paris", + "Sunbeam on the floor,\nA furry puddle sleeping,\nTwitching tail tells tales.", + "def factorial(n):\n if n < 0:\n return 'Factorial does not exist for negative numbers'\n elif n == 0:\n return 1\n else:\n fact = 1\n i = 1\n while i <= n:\n fact *= i\n i += 1\n return fact", + "Comment ça va ?", + ] + }) + + inference_results = client.evals.run_inference( + model="gemini-2.5-flash-preview-05-20", + src=prompts_df + ) + +Then run evaluation by providing the inference results and specifying the metric types. + +.. code-block:: Python + + eval_result = client.evals.evaluate( + dataset=inference_results, + metrics=[ + types.Metric(name='exact_match'), + types.Metric(name='rouge_l_sum'), + types.PrebuiltMetric.TEXT_QUALITY, + ] + ) + ----------------------------------------- |GA| |pypi| |versions| |unit-tests| |system-tests| |sample-tests| From b4708de6f50b50eda912e25de5864db9a514a880 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Mon, 16 Jun 2025 13:53:33 -0700 Subject: [PATCH 10/24] chore: fix async GenAI SDK client support PiperOrigin-RevId: 772159610 --- tests/unit/vertexai/genai/test_genai_client.py | 9 ++++++++- vertexai/_genai/client.py | 10 +++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/unit/vertexai/genai/test_genai_client.py b/tests/unit/vertexai/genai/test_genai_client.py index 476519ff36..94f465f809 100644 --- a/tests/unit/vertexai/genai/test_genai_client.py +++ b/tests/unit/vertexai/genai/test_genai_client.py @@ -16,11 +16,12 @@ # pylint: disable=protected-access,bad-continuation import importlib +import pytest from google.cloud import aiplatform import vertexai from google.cloud.aiplatform import initializer as aiplatform_initializer -import pytest + _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" @@ -48,3 +49,9 @@ def test_genai_client(self): assert test_client._api_client.vertexai assert test_client._api_client.project == _TEST_PROJECT assert test_client._api_client.location == _TEST_LOCATION + + @pytest.mark.asyncio + @pytest.mark.usefixtures("google_auth_mock") + async def test_async_client(self): + test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + assert isinstance(test_client.aio, vertexai._genai.client.AsyncClient) diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py index 6371eb727e..a0cb0052df 100644 --- a/vertexai/_genai/client.py +++ 
b/vertexai/_genai/client.py @@ -28,7 +28,6 @@ class AsyncClient: def __init__(self, api_client: client.Client): self._api_client = api_client - self._aio = AsyncClient(self._api_client) self._evals = None @property @@ -102,6 +101,7 @@ def __init__( debug_config=self._debug_config, http_options=http_options, ) + self._aio = AsyncClient(self._api_client) self._evals = None self._prompt_optimizer = None @@ -134,3 +134,11 @@ def prompt_optimizer(self): ".prompt_optimizer", __package__ ) return self._prompt_optimizer.PromptOptimizer(self._api_client) + + @property + @_common.experimental_warning( + "The Vertex SDK GenAI async client is experimental, " + "and may change in future versions." + ) + def aio(self): + return self._aio From 959d79869468c1fa66b7691eb8c4071a5af3eec4 Mon Sep 17 00:00:00 2001 From: Frances Hubis Thoma Date: Mon, 16 Jun 2025 14:42:17 -0700 Subject: [PATCH 11/24] feat: Enable Vertex Multimodal Dataset as input to supervised fine-tuning. PiperOrigin-RevId: 772177718 --- vertexai/tuning/_supervised_tuning.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vertexai/tuning/_supervised_tuning.py b/vertexai/tuning/_supervised_tuning.py index 54cfc4fbc0..7f3f8616c5 100644 --- a/vertexai/tuning/_supervised_tuning.py +++ b/vertexai/tuning/_supervised_tuning.py @@ -15,6 +15,7 @@ from typing import Dict, Literal, Optional, Union +from google.cloud.aiplatform.preview import datasets from google.cloud.aiplatform.utils import _ipython_utils from google.cloud.aiplatform_v1beta1.types import ( tuning_job as gca_tuning_job_types, @@ -26,8 +27,8 @@ def train( *, source_model: Union[str, generative_models.GenerativeModel], - train_dataset: str, - validation_dataset: Optional[str] = None, + train_dataset: Union[str, datasets.MultimodalDataset], + validation_dataset: Optional[Union[str, datasets.MultimodalDataset]] = None, tuned_model_display_name: Optional[str] = None, epochs: Optional[int] = None, learning_rate_multiplier: Optional[float] = None, @@ -38,8 +39,8 @@ def train( Args: source_model (str): Model name for tuning, e.g., "gemini-1.0-pro-002". - train_dataset: Training dataset used for tuning. The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset. - validation_dataset: Validation dataset used for tuning. The dataset can be specified as either a Cloud Storage path to a JSONL file or as the resource name of a Vertex Multimodal Dataset. + train_dataset: Training dataset used for tuning. The dataset can be a JSONL file on Google Cloud Storage (specified as its GCS URI) or a Vertex Multimodal Dataset (either as the dataset object itself or as its resource name). + validation_dataset: Validation dataset used for tuning. The dataset can be a JSONL file on Google Cloud Storage (specified as its GCS URI) or a Vertex Multimodal Dataset (either as the dataset object itself or as the resource name). tuned_model_display_name: The display name of the [TunedModel][google.cloud.aiplatform.v1.Model]. The name can be up to 128 characters long and can consist of any UTF-8 characters. @@ -73,6 +74,10 @@ def train( raise ValueError( f"Unsupported adapter size: {adapter_size}. 
The supported sizes are [1, 4, 8, 16]" ) + if isinstance(train_dataset, datasets.MultimodalDataset): + train_dataset = train_dataset.resource_name + if isinstance(validation_dataset, datasets.MultimodalDataset): + validation_dataset = validation_dataset.resource_name supervised_tuning_spec = gca_tuning_job_types.SupervisedTuningSpec( training_dataset_uri=train_dataset, validation_dataset_uri=validation_dataset, From cd5be581c3ed921d666aa3b75b36a39d314a6c12 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Mon, 16 Jun 2025 15:26:37 -0700 Subject: [PATCH 12/24] chore: update metric prompt builder templates PiperOrigin-RevId: 772194657 --- vertexai/_genai/types.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vertexai/_genai/types.py b/vertexai/_genai/types.py index df5105053d..18d588c474 100644 --- a/vertexai/_genai/types.py +++ b/vertexai/_genai/types.py @@ -2824,12 +2824,16 @@ def _prepare_fields_and_construct_text(cls, data: Any) -> Any: template_parts.extend( [ - "\n# User Inputs and AI-generated Response", - "## User Inputs", + "\n", + "# User Inputs and AI-generated Response", + "## User Prompt", + "{prompt}", + "\n", + "## AI-generated Response", + "{response}", ] ) - template_parts.extend(["## AI-generated Response", "{response}"]) constructed_text = "\n".join(template_parts) data["text"] = constructed_text From 364db79f35c43ac9a68ccef125f6e4e35a1e0372 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Mon, 16 Jun 2025 16:36:36 -0700 Subject: [PATCH 13/24] chore: add `evaluation_dataset` field to the `EvaluationResult` type PiperOrigin-RevId: 772218402 --- tests/unit/vertexai/genai/test_evals.py | 28 ++++++++++- vertexai/_genai/_evals_common.py | 65 +++++++++++++++++-------- vertexai/_genai/evals.py | 8 ++- vertexai/_genai/types.py | 13 +++-- 4 files changed, 84 insertions(+), 30 deletions(-) diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index a31aa18ddf..4bf55cc392 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -632,8 +632,20 @@ def test_inference_with_row_level_config_overrides( prompt_feedback=None, ), ] + + def mock_generate_content_logic(*args, **kwargs): + contents = kwargs.get("contents") + first_part_text = contents[0]["parts"][0]["text"] + if "Placeholder prompt 1" in first_part_text: + return mock_generate_content_responses[0] + elif "Placeholder prompt 2.1" in first_part_text: + return mock_generate_content_responses[1] + elif "Placeholder prompt 3" in first_part_text: + return mock_generate_content_responses[2] + return genai_types.GenerateContentResponse() + mock_models.return_value.generate_content.side_effect = ( - mock_generate_content_responses + mock_generate_content_logic ) inference_result = self.client.evals.run_inference( @@ -2437,6 +2449,7 @@ def test_execute_evaluation_computation_metric( ) assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "exact_match" @@ -2477,6 +2490,7 @@ def test_execute_evaluation_translation_metric( metrics=[translation_metric], ) assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "comet" @@ -2504,6 +2518,7 @@ def 
test_execute_evaluation_llm_metric( metrics=[llm_metric], ) assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "text_quality" @@ -2542,6 +2557,7 @@ def my_custom_metric_fn(data: dict): metrics=[custom_metric], ) assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "my_custom" @@ -2587,6 +2603,7 @@ def test_llm_metric_default_aggregation_mixed_results( ) assert mock_llm_process.call_count == 3 + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary = result.summary_metrics[0] assert summary.metric_name == "quality" @@ -2637,6 +2654,7 @@ def custom_agg_fn(results: list[vertexai_genai_types.EvalCaseMetricResult]): metrics=[llm_metric], ) assert mock_llm_process.call_count == 2 + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary = result.summary_metrics[0] assert summary.metric_name == "custom_quality" @@ -2682,6 +2700,7 @@ def custom_agg_fn_error( metrics=[llm_metric], ) assert mock_llm_process.call_count == 2 + assert result.evaluation_dataset == [input_dataset] summary = result.summary_metrics[0] assert summary.metric_name == "error_fallback_quality" assert summary.num_cases_total == 2 @@ -2719,6 +2738,7 @@ def custom_agg_fn_invalid_type( dataset=input_dataset, metrics=[llm_metric], ) + assert result.evaluation_dataset == [input_dataset] summary = result.summary_metrics[0] assert summary.mean_score == 0.8 assert summary.num_cases_valid == 1 @@ -2747,6 +2767,7 @@ def test_execute_evaluation_lazy_loaded_prebuilt_metric_instance( mock_api_client_fixture ) assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "safety" @@ -2774,6 +2795,7 @@ def test_execute_evaluation_prebuilt_metric_via_loader( mock_api_client_fixture ) assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "safety" @@ -2805,7 +2827,9 @@ def test_execute_evaluation_with_gcs_destination( ) mock_eval_dependencies["mock_upload_to_gcs"].assert_called_once_with( - data=result.model_dump(mode="json", exclude_none=True), + data=result.model_dump( + mode="json", exclude_none=True, exclude={"evaluation_dataset"} + ), gcs_dest_prefix=gcs_dest, filename_prefix="evaluation_result", ) diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index a4e1564f87..c460c64950 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -176,7 +176,7 @@ def _extract_contents_for_inference( def _execute_inference_concurrently( api_client: BaseApiClient, model_or_fn: Union[str, Callable[[Any], Any]], - prompt_dataset: "pd.DataFrame", + prompt_dataset: pd.DataFrame, progress_desc: str, gemini_config: Optional[genai_types.GenerateContentConfig] = None, inference_fn: Optional[Callable[[Any, Any, Any, Any], Any]] = None, @@ -251,7 +251,7 @@ def _execute_inference_concurrently( def 
_run_gemini_inference( api_client: BaseApiClient, model: str, - prompt_dataset: "pd.DataFrame", + prompt_dataset: pd.DataFrame, config: Optional[genai_types.GenerateContentConfig] = None, ) -> list[Union[genai_types.GenerateContentResponse, dict[str, Any]]]: """Internal helper to run inference using Gemini model with concurrency.""" @@ -533,7 +533,7 @@ def _get_dataset_source( def _resolve_dataset_inputs( - dataset: Union[types.EvaluationDataset, list[types.EvaluationDataset]], + dataset: list[types.EvaluationDataset], dataset_schema: Optional[Literal["gemini", "flatten"]], loader: _evals_utils.EvalDatasetLoader, ) -> tuple[types.EvaluationDataset, int]: @@ -552,19 +552,12 @@ def _resolve_dataset_inputs( evaluation cases. - num_response_candidates: The number of response candidates. """ - num_response_candidates: int - datasets_to_process: list[types.EvaluationDataset] - - if isinstance(dataset, list): - if not dataset: - raise ValueError("Input dataset list cannot be empty.") - num_response_candidates = len(dataset) - datasets_to_process = dataset - logger.info("Processing %s datasets for comparison.", num_response_candidates) - else: - num_response_candidates = 1 - datasets_to_process = [dataset] - logger.info("Processing a single dataset.") + if not dataset: + raise ValueError("Input dataset list cannot be empty.") + + num_response_candidates = len(dataset) + datasets_to_process = dataset + logger.info("Processing %s dataset(s).", num_response_candidates) loaded_raw_datasets: list[list[dict[str, Any]]] = [] schemas_for_merge: list[str] = [] @@ -667,14 +660,40 @@ def _execute_evaluation( dataset_schema: Optional[Literal["gemini", "flatten"]] = None, dest: Optional[str] = None, ) -> types.EvaluationResult: - """Evaluates a dataset using the provided metrics.""" + """Evaluates a dataset using the provided metrics. + + Args: + api_client: The API client. + dataset: The dataset to evaluate. + metrics: The metrics to evaluate the dataset against. + dataset_schema: The schema of the dataset. + dest: The destination to save the evaluation results. + + Returns: + The evaluation result. + """ logger.info("Preparing dataset(s) and metrics...") - loader = _evals_utils.EvalDatasetLoader(api_client=api_client) + if isinstance(dataset, types.EvaluationDataset): + dataset_list = [dataset] + elif isinstance(dataset, list): + for item in dataset: + if not isinstance(item, types.EvaluationDataset): + raise TypeError( + f"Unsupported dataset type: {type(item)}. " + "Must be EvaluationDataset." + ) + dataset_list = dataset + else: + raise TypeError( + f"Unsupported dataset type: {type(dataset)}. Must be an" + " EvaluationDataset or a list of EvaluationDataset." 
+ ) + loader = _evals_utils.EvalDatasetLoader(api_client=api_client) processed_eval_dataset, num_response_candidates = _resolve_dataset_inputs( - dataset=dataset, dataset_schema=dataset_schema, loader=loader + dataset=dataset_list, dataset_schema=dataset_schema, loader=loader ) resolved_metrics = _resolve_metrics(metrics, api_client) @@ -693,13 +712,19 @@ def _execute_evaluation( ) t2 = time.perf_counter() logger.info("Evaluation took: %f seconds", t2 - t1) + + evaluation_result.evaluation_dataset = dataset_list logger.info("Evaluation run completed.") if dest: uploaded_path = _evals_utils.GcsUtils( api_client=api_client ).upload_json_to_prefix( - data=evaluation_result.model_dump(mode="json", exclude_none=True), + data=evaluation_result.model_dump( + mode="json", + exclude_none=True, + exclude={"evaluation_dataset"}, + ), gcs_dest_prefix=dest, filename_prefix="evaluation_result", ) diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py index 093b6e6def..506ea8200b 100644 --- a/vertexai/_genai/evals.py +++ b/vertexai/_genai/evals.py @@ -1240,11 +1240,9 @@ def evaluate( config = types.EvaluateMethodConfig.model_validate(config) if isinstance(dataset, list): dataset = [ - ( - types.EvaluationDataset.model_validate(ds_item) - if isinstance(ds_item, dict) - else ds_item - ) + types.EvaluationDataset.model_validate(ds_item) + if isinstance(ds_item, dict) + else ds_item for ds_item in dataset ] else: diff --git a/vertexai/_genai/types.py b/vertexai/_genai/types.py index 18d588c474..e87965e339 100644 --- a/vertexai/_genai/types.py +++ b/vertexai/_genai/types.py @@ -2176,9 +2176,9 @@ def to_yaml_file(self, file_path: str, version: Optional[str] = None) -> None: exclude_unset=True, exclude_none=True, mode="json", - exclude=( - fields_to_exclude_callables if fields_to_exclude_callables else None - ), + exclude=fields_to_exclude_callables + if fields_to_exclude_callables + else None, ) if version: @@ -3097,6 +3097,10 @@ class EvaluationResult(_common.BaseModel): default=None, description="""A list of summary-level evaluation results for each metric.""", ) + evaluation_dataset: Optional[list[EvaluationDataset]] = Field( + default=None, + description="""The input evaluation dataset(s) for the evaluation run.""", + ) metadata: Optional[EvaluationRunMetadata] = Field( default=None, description="""Metadata for the evaluation run.""" ) @@ -3111,6 +3115,9 @@ class EvaluationResultDict(TypedDict, total=False): summary_metrics: Optional[list[AggregatedMetricResultDict]] """A list of summary-level evaluation results for each metric.""" + evaluation_dataset: Optional[list[EvaluationDatasetDict]] + """The input evaluation dataset(s) for the evaluation run.""" + metadata: Optional[EvaluationRunMetadataDict] """Metadata for the evaluation run.""" From 0059c01b7395fc93be8d9214c938299678f67d3e Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 17 Jun 2025 09:47:58 -0700 Subject: [PATCH 14/24] fix: Update supported python version for create_reasoning_engine PiperOrigin-RevId: 772511618 --- vertexai/reasoning_engines/_reasoning_engines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index 0703c63359..2b9efbe06f 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -224,7 +224,7 @@ def create( use for staging the artifacts needed. sys_version (str): Optional. The Python system version used. 
Currently supports any - of "3.8", "3.9", "3.10", "3.11", "3.12". If not specified, + of "3.9", "3.10", "3.11", "3.12", "3.13". If not specified, it defaults to the "{major}.{minor}" attributes of sys.version_info. extra_packages (Sequence[str]): From e762008a146578d3706f1f9b2898081e51e705f4 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Tue, 17 Jun 2025 10:29:47 -0700 Subject: [PATCH 15/24] chore: expand support for response output types in multimodal gemini format for evaluation PiperOrigin-RevId: 772528216 --- tests/unit/vertexai/genai/test_evals.py | 32 +++++- vertexai/_genai/_evals_data_converters.py | 121 +++++++++++++++------- 2 files changed, 111 insertions(+), 42 deletions(-) diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 4bf55cc392..0bd23ec7e8 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -92,8 +92,8 @@ def test_eval_batch_eval(self, mock_evaluate, mock_get_api_client): mock_evaluate.assert_called_once() -class TestEvalsClientInference: - """Unit tests for the Evals client inference method.""" +class TestEvalsRunInference: + """Unit tests for the Evals run_inference method.""" def setup_method(self): importlib.reload(aiplatform_initializer) @@ -1229,7 +1229,10 @@ def test_convert_no_candidates_in_response(self): result_dataset = self.converter.convert(raw_data) eval_case = result_dataset.eval_cases[0] assert len(eval_case.responses) == 1 - assert eval_case.responses[0].response is None + assert ( + eval_case.responses[0].response.parts[0].text + == _evals_data_converters._PLACEHOLDER_RESPONSE_TEXT + ) def test_convert_invalid_content_structure_raises_value_error(self): raw_data = [ @@ -1300,6 +1303,29 @@ def test_convert_multiple_items(self): assert result_dataset.eval_cases[0].prompt.parts[0].text == "Item 1" assert result_dataset.eval_cases[1].prompt.parts[0].text == "Item 2" + def test_convert_with_raw_string_response(self): + """Tests conversion when the response is a raw string.""" + raw_data = [ + { + "request": { + "contents": [{"role": "user", "parts": [{"text": "Hello"}]}] + }, + "response": "Hi from a raw string", + } + ] + result_dataset = self.converter.convert(raw_data) + assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert len(result_dataset.eval_cases) == 1 + eval_case = result_dataset.eval_cases[0] + + assert eval_case.prompt == genai_types.Content( + parts=[genai_types.Part(text="Hello")], role="user" + ) + assert len(eval_case.responses) == 1 + assert eval_case.responses[0].response == genai_types.Content( + parts=[genai_types.Part(text="Hi from a raw string")], + ) + class TestFlattenEvalDataConverter: """Unit tests for the _FlattenEvalDataConverter class.""" diff --git a/vertexai/_genai/_evals_data_converters.py b/vertexai/_genai/_evals_data_converters.py index ee794bcfd8..cb41f3c461 100644 --- a/vertexai/_genai/_evals_data_converters.py +++ b/vertexai/_genai/_evals_data_converters.py @@ -27,6 +27,17 @@ logger = logging.getLogger("vertexai_genai._evals_data_converters") +_PLACEHOLDER_RESPONSE_TEXT = "Error: Missing response for this candidate" + + +def _create_placeholder_response_candidate( + text: str = _PLACEHOLDER_RESPONSE_TEXT, +) -> types.ResponseCandidate: + """Creates a ResponseCandidate with placeholder text.""" + return types.ResponseCandidate( + response=genai_types.Content(parts=[genai_types.Part(text=text)]) + ) + class EvalDatasetSchema(_common.CaseInSensitiveEnum): """Represents the schema of an evaluation 
dataset.""" @@ -115,23 +126,40 @@ def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: reference, ) = self._parse_request(request_data) - generate_content_response = ( - genai_types.GenerateContentResponse.model_validate(response_data) - ) - responses = [] - if generate_content_response.candidates: - candidate = generate_content_response.candidates[0] - if candidate.content: - responses.append( - types.ResponseCandidate( - response=genai_types.Content.model_validate( - candidate.content - ) + if isinstance(response_data, str): + responses.append( + types.ResponseCandidate( + response=genai_types.Content( + parts=[genai_types.Part(text=response_data)] ) ) - else: # Handle cases where there are no candidates (e.g., prompt blocked) - responses.append(types.ResponseCandidate(response=None)) + ) + elif isinstance(response_data, dict): + try: + generate_content_response = ( + genai_types.GenerateContentResponse.model_validate( + response_data + ) + ) + if generate_content_response.candidates: + candidate = generate_content_response.candidates[0] + if candidate.content: + responses.append( + types.ResponseCandidate( + response=genai_types.Content.model_validate( + candidate.content + ) + ) + ) + else: # Handle cases where there are no candidates. + responses.append(_create_placeholder_response_candidate()) + except Exception: + # Fallback for dicts that don't match the schema, treat as empty. + responses.append(_create_placeholder_response_candidate()) + else: + # For any other type, treat as an empty/invalid response. + responses.append(_create_placeholder_response_candidate()) eval_case = types.EvalCase( eval_case_id=eval_case_id, @@ -268,29 +296,56 @@ def auto_detect_dataset_schema( ) -> EvalDatasetSchema: """Detects the schema of a raw dataset.""" if not raw_dataset: + logger.debug("Empty dataset, returning UNKNOWN schema.") return EvalDatasetSchema.UNKNOWN first_item = raw_dataset[0] - try: - _GeminiEvalDataConverter().convert([first_item]) - return EvalDatasetSchema.GEMINI - except (ValueError, KeyError, AttributeError, TypeError) as e: - logger.debug( - "First item not parsable as Gemini schema (error: %s), " - "falling back to key-based checks.", - e, + if not isinstance(first_item, dict): + logger.warning( + "First item in dataset is not a dictionary. Cannot determine schema." ) - pass + return EvalDatasetSchema.UNKNOWN - # Fallback to key-based detection for flatten schema keys = set(first_item.keys()) + + request_field = first_item.get("request") + if isinstance(request_field, dict) and isinstance( + request_field.get("contents"), list + ): + try: + _GeminiEvalDataConverter().convert([first_item]) + logger.debug( + "Detected GEMINI schema based on 'request.contents' presence and" + " successful conversion." + ) + return EvalDatasetSchema.GEMINI + except (ValueError, KeyError, AttributeError, TypeError) as e: + logger.debug( + "First item looked like Gemini schema (due to 'request.contents') but" + " conversion failed (error: %s). Will try other schemas.", + e, + ) + + # Check for flatten schema if Gemini check failed or wasn't applicable if {"prompt", "response"}.issubset(keys) or { "response", "reference", }.issubset(keys): - return EvalDatasetSchema.FLATTEN - else: - return EvalDatasetSchema.UNKNOWN + try: + _FlattenEvalDataConverter().convert([first_item]) + logger.debug( + "Detected FLATTEN schema based on key presence and successful" + " conversion." 
+ ) + return EvalDatasetSchema.FLATTEN + except (ValueError, KeyError, AttributeError, TypeError) as e: + logger.debug( + "Flatten schema key check passed, but conversion failed (error: %s).", + e, + ) + + logger.debug("Could not confidently determine schema. Returning UNKNOWN.") + return EvalDatasetSchema.UNKNOWN _SCHEMA_TO_CONVERTER = { @@ -332,18 +387,6 @@ def _get_text_from_reference( return None -_PLACEHOLDER_RESPONSE_TEXT = "Error: Missing response for this candidate" - - -def _create_placeholder_response_candidate( - text: str = _PLACEHOLDER_RESPONSE_TEXT, -) -> types.ResponseCandidate: - """Creates a ResponseCandidate with placeholder text.""" - return types.ResponseCandidate( - response=genai_types.Content(parts=[genai_types.Part(text=text)]) - ) - - def _validate_case_consistency( base_case: types.EvalCase, current_case: types.EvalCase, From 48f2d7476afc4f629657cad7b0c551f122a59b84 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Tue, 17 Jun 2025 11:12:16 -0700 Subject: [PATCH 16/24] chore: move SDK info back to top of readme PiperOrigin-RevId: 772546842 --- README.rst | 45 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/README.rst b/README.rst index 425931cf0a..9ef50aabb1 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,28 @@ Vertex AI SDK for Python ================================================= +|GA| |pypi| |versions| |unit-tests| |system-tests| |sample-tests| + +`Vertex AI`_: Google Vertex AI is an integrated suite of machine learning tools and services for building and using ML models with AutoML or custom code. It offers both novices and experts the best workbench for the entire machine learning development lifecycle. + +- `Client Library Documentation`_ +- `Product Documentation`_ + +.. |GA| image:: https://img.shields.io/badge/support-ga-gold.svg + :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#general-availability +.. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-aiplatform.svg + :target: https://pypi.org/project/google-cloud-aiplatform/ +.. |versions| image:: https://img.shields.io/pypi/pyversions/google-cloud-aiplatform.svg + :target: https://pypi.org/project/google-cloud-aiplatform/ +.. |unit-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-unit-tests.svg + :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-unit-tests.html +.. |system-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-system-tests.svg + :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-system-tests.html +.. |sample-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-sample-tests.svg + :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-sample-tests.html +.. _Vertex AI: https://cloud.google.com/vertex-ai/docs +.. _Client Library Documentation: https://cloud.google.com/python/docs/reference/aiplatform/latest +.. 
_Product Documentation: https://cloud.google.com/vertex-ai/docs Gemini API and Generative AI on Vertex AI ----------------------------------------- @@ -71,29 +93,6 @@ Then run evaluation by providing the inference results and specifying the metric ----------------------------------------- -|GA| |pypi| |versions| |unit-tests| |system-tests| |sample-tests| - -`Vertex AI`_: Google Vertex AI is an integrated suite of machine learning tools and services for building and using ML models with AutoML or custom code. It offers both novices and experts the best workbench for the entire machine learning development lifecycle. - -- `Client Library Documentation`_ -- `Product Documentation`_ - -.. |GA| image:: https://img.shields.io/badge/support-ga-gold.svg - :target: https://github.com/googleapis/google-cloud-python/blob/main/README.rst#general-availability -.. |pypi| image:: https://img.shields.io/pypi/v/google-cloud-aiplatform.svg - :target: https://pypi.org/project/google-cloud-aiplatform/ -.. |versions| image:: https://img.shields.io/pypi/pyversions/google-cloud-aiplatform.svg - :target: https://pypi.org/project/google-cloud-aiplatform/ -.. |unit-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-unit-tests.svg - :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-unit-tests.html -.. |system-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-system-tests.svg - :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-system-tests.html -.. |sample-tests| image:: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-sample-tests.svg - :target: https://storage.googleapis.com/cloud-devrel-public/python-aiplatform/badges/sdk-sample-tests.html -.. _Vertex AI: https://cloud.google.com/vertex-ai/docs -.. _Client Library Documentation: https://cloud.google.com/python/docs/reference/aiplatform/latest -.. 
_Product Documentation: https://cloud.google.com/vertex-ai/docs - Quick Start ----------- From 865a68c1273aa4e4e946a203bf226b80a723523f Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 17 Jun 2025 11:15:47 -0700 Subject: [PATCH 17/24] feat: Update v1 `create_corpus` to accept `encryption_spec` in `rag_data.py` PiperOrigin-RevId: 772548087 --- tests/unit/vertex_rag/test_rag_constants.py | 20 +++++++++++++++++ tests/unit/vertex_rag/test_rag_data.py | 24 ++++++++++++++++++++ vertexai/rag/rag_data.py | 9 ++++++++ vertexai/rag/utils/_gapic_utils.py | 25 +++++++++++++++++++++ vertexai/rag/utils/resources.py | 3 +++ 5 files changed, 81 insertions(+) diff --git a/tests/unit/vertex_rag/test_rag_constants.py b/tests/unit/vertex_rag/test_rag_constants.py index 73b037cdd4..f99e042b7b 100644 --- a/tests/unit/vertex_rag/test_rag_constants.py +++ b/tests/unit/vertex_rag/test_rag_constants.py @@ -62,6 +62,7 @@ VertexAiSearchConfig as GapicVertexAiSearchConfig, ) from google.cloud.aiplatform_v1.types import api_auth +from google.cloud.aiplatform_v1.types import EncryptionSpec from google.protobuf import timestamp_pb2 from google.cloud.aiplatform_v1.types import content @@ -83,6 +84,9 @@ TEST_WEAVIATE_API_KEY_SECRET_VERSION = ( "projects/test-project/secrets/test-secret/versions/1" ) +TEST_ENCRYPTION_SPEC = EncryptionSpec( + kms_key_name="projects/test-project/locations/us-central1/keyRings/test-key-ring/cryptoKeys/test-key" +) TEST_PINECONE_INDEX_NAME = "test-pinecone-index" TEST_PINECONE_API_KEY_SECRET_VERSION = ( "projects/test-project/secrets/test-secret/versions/1" @@ -106,6 +110,14 @@ TEST_GAPIC_RAG_CORPUS.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = "projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format( TEST_PROJECT, TEST_REGION ) +TEST_GAPIC_CMEK_RAG_CORPUS = GapicRagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + encryption_spec=EncryptionSpec( + kms_key_name="projects/test-project/locations/us-central1/keyRings/test-key-ring/cryptoKeys/test-key" + ), +) TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH = GapicRagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, @@ -148,6 +160,14 @@ display_name=TEST_CORPUS_DISPLAY_NAME, backend_config=TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, ) +TEST_CMEK_RAG_CORPUS = RagCorpus( + name=TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=TEST_CORPUS_DISPLAY_NAME, + description=TEST_CORPUS_DISCRIPTION, + encryption_spec=EncryptionSpec( + kms_key_name="projects/test-project/locations/us-central1/keyRings/test-key-ring/cryptoKeys/test-key" + ), +) TEST_BACKEND_CONFIG_PINECONE_CONFIG = RagVectorDbConfig( vector_db=TEST_PINECONE_CONFIG, ) diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index c0e3d2d1cf..da50f63145 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -53,6 +53,21 @@ def create_rag_corpus_mock(): yield create_rag_corpus_mock +@pytest.fixture +def create_rag_corpus_mock_cmek(): + with mock.patch.object( + VertexRagDataServiceClient, + "create_rag_corpus", + ) as create_rag_corpus_mock_cmek: + create_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation) + create_rag_corpus_lro_mock.done.return_value = True + create_rag_corpus_lro_mock.result.return_value = ( + test_rag_constants.TEST_GAPIC_CMEK_RAG_CORPUS + ) + create_rag_corpus_mock_cmek.return_value = create_rag_corpus_lro_mock + yield 
create_rag_corpus_mock_cmek + + @pytest.fixture def create_rag_corpus_mock_vertex_vector_search(): with mock.patch.object( @@ -373,6 +388,15 @@ def test_create_corpus_vertex_vector_search_success(self): rag_corpus, test_rag_constants.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH ) + @pytest.mark.usefixtures("create_rag_corpus_mock_cmek") + def test_create_corpus_cmek_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants.TEST_CORPUS_DISPLAY_NAME, + encryption_spec=test_rag_constants.TEST_ENCRYPTION_SPEC, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants.TEST_CMEK_RAG_CORPUS) + @pytest.mark.usefixtures("create_rag_corpus_mock_pinecone") def test_create_corpus_pinecone_success(self): rag_corpus = rag.create_corpus( diff --git a/vertexai/rag/rag_data.py b/vertexai/rag/rag_data.py index ec684b0969..0c5fc3975e 100644 --- a/vertexai/rag/rag_data.py +++ b/vertexai/rag/rag_data.py @@ -46,6 +46,7 @@ from vertexai.rag.utils import ( _gapic_utils, ) +from google.cloud.aiplatform_v1.types import EncryptionSpec from vertexai.rag.utils.resources import ( JiraSource, LayoutParserConfig, @@ -71,6 +72,7 @@ def create_corpus( None, ] ] = None, + encryption_spec: Optional[EncryptionSpec] = None, ) -> RagCorpus: """Creates a new RagCorpus resource. @@ -96,6 +98,7 @@ def create_corpus( specified. backend_config: The backend config of the RagCorpus, specifying a data store and/or embedding model. + encryption_spec: The encryption spec of the RagCorpus. Returns: RagCorpus. Raises: @@ -124,6 +127,12 @@ def create_corpus( rag_corpus=rag_corpus, ) + if encryption_spec: + _gapic_utils.set_encryption_spec( + encryption_spec=encryption_spec, + rag_corpus=rag_corpus, + ) + request = CreateRagCorpusRequest( parent=parent, rag_corpus=rag_corpus, diff --git a/vertexai/rag/utils/_gapic_utils.py b/vertexai/rag/utils/_gapic_utils.py index e98d082f96..997131f7fc 100644 --- a/vertexai/rag/utils/_gapic_utils.py +++ b/vertexai/rag/utils/_gapic_utils.py @@ -17,6 +17,7 @@ import re from typing import Any, Dict, Optional, Sequence, Union from google.cloud.aiplatform_v1.types import api_auth +from google.cloud.aiplatform_v1.types import EncryptionSpec from google.cloud.aiplatform_v1 import ( RagEmbeddingModelConfig as GapicRagEmbeddingModelConfig, GoogleDriveSource, @@ -203,6 +204,7 @@ def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus: backend_config=convert_gapic_to_backend_config( gapic_rag_corpus.vector_db_config ), + encryption_spec=gapic_rag_corpus.encryption_spec, ) return rag_corpus @@ -223,6 +225,7 @@ def convert_gapic_to_rag_corpus_no_embedding_model_config( backend_config=convert_gapic_to_backend_config( rag_vector_db_config_no_embedding_model_config ), + encryption_spec=gapic_rag_corpus.encryption_spec, ) return rag_corpus @@ -660,6 +663,28 @@ def set_backend_config( ) +def set_encryption_spec( + encryption_spec: EncryptionSpec, + rag_corpus: GapicRagCorpus, +) -> None: + """Sets the encryption spec for the rag corpus.""" + # Raises value error if encryption_spec.kms_key_name is None or empty, + if encryption_spec.kms_key_name is None or not encryption_spec.kms_key_name: + raise ValueError("kms_key_name must be set if encryption_spec is set.") + + # Raises value error if encryption_spec.kms_key_name is not a valid KMS key name. 
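For context, a minimal usage sketch of the ``encryption_spec`` argument that this patch adds to ``create_corpus``, mirroring the ``test_create_corpus_cmek_success`` unit test above. This is an illustrative sketch, not code from the patch; the project, location, key ring, and key names are placeholders.

.. code-block:: Python

    # Hedged sketch: create a RagCorpus protected by a customer-managed
    # encryption key using the new encryption_spec argument. All resource
    # names below are placeholders.
    import vertexai
    from vertexai import rag
    from google.cloud.aiplatform_v1.types import EncryptionSpec

    vertexai.init(project="my-project", location="us-central1")

    corpus = rag.create_corpus(
        display_name="my-cmek-corpus",
        encryption_spec=EncryptionSpec(
            kms_key_name=(
                "projects/my-project/locations/us-central1/"
                "keyRings/my-key-ring/cryptoKeys/my-key"
            )
        ),
    )
    print(corpus.name)

The key name must match the ``projects/{project}/locations/{location}/keyRings/{key_ring}/cryptoKeys/{crypto_key}`` format that ``set_encryption_spec`` validates below.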
+ if not re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/keyRings/(?P.+?)/cryptoKeys/(?P.+?)$", + encryption_spec.kms_key_name, + ): + raise ValueError( + "kms_key_name must be of the format " + "`projects/{project}/locations/{location}/keyRings/{key_ring}/cryptoKeys/{crypto_key}`" + ) + + rag_corpus.encryption_spec = encryption_spec + + def set_vertex_ai_search_config( vertex_ai_search_config: VertexAiSearchConfig, rag_corpus: GapicRagCorpus, diff --git a/vertexai/rag/utils/resources.py b/vertexai/rag/utils/resources.py index 407ff80f4b..6f65583f49 100644 --- a/vertexai/rag/utils/resources.py +++ b/vertexai/rag/utils/resources.py @@ -19,6 +19,7 @@ from typing import List, Optional, Sequence, Union from google.protobuf import timestamp_pb2 +from google.cloud.aiplatform_v1.types import EncryptionSpec @dataclasses.dataclass @@ -190,6 +191,7 @@ class RagCorpus: vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. backend_config: The backend config of the RagCorpus. It can be a data store and/or retrieval engine. + encryption_spec: The encryption spec of the RagCorpus. Immutable. """ name: Optional[str] = None @@ -202,6 +204,7 @@ class RagCorpus: None, ] ] = None + encryption_spec: Optional[EncryptionSpec] = None @dataclasses.dataclass From f1c8c2f50017325e4417045eea24b7b800789fb5 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Tue, 17 Jun 2025 12:23:15 -0700 Subject: [PATCH 18/24] chore: enable 'from vertexai.types import *' for GenAI client PiperOrigin-RevId: 772574048 --- tests/unit/vertexai/genai/test_genai_client.py | 5 +++++ vertexai/__init__.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/tests/unit/vertexai/genai/test_genai_client.py b/tests/unit/vertexai/genai/test_genai_client.py index 94f465f809..9bf71cd9f6 100644 --- a/tests/unit/vertexai/genai/test_genai_client.py +++ b/tests/unit/vertexai/genai/test_genai_client.py @@ -55,3 +55,8 @@ def test_genai_client(self): async def test_async_client(self): test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) assert isinstance(test_client.aio, vertexai._genai.client.AsyncClient) + + @pytest.mark.usefixtures("google_auth_mock") + def test_types(self): + assert vertexai.types is not None + assert vertexai.types.LLMMetric is not None diff --git a/vertexai/__init__.py b/vertexai/__init__.py index 2e8a2f2fd9..1444612f49 100644 --- a/vertexai/__init__.py +++ b/vertexai/__init__.py @@ -15,6 +15,7 @@ """The vertexai module.""" import importlib +import sys from google.cloud.aiplatform import version as aiplatform_version @@ -47,6 +48,8 @@ def __getattr__(name): global _genai_types if _genai_types is None: _genai_types = importlib.import_module("._genai.types", __name__) + if "vertexai.types" not in sys.modules: + sys.modules["vertexai.types"] = _genai_types return _genai_types raise AttributeError(f"module '{__name__}' has no attribute '{name}'") From c5bb99b80dbbc76ababdba1228154717370eb5dd Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Tue, 17 Jun 2025 14:26:05 -0700 Subject: [PATCH 19/24] Copybara import of the project: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -- adcd611ba29acef5c6a3270cd4edae1bbcb616ea by Owl Bot : feat: add DnsPeeringConfig in service_networking.proto feat: add dns_peering_configs to PscInterfaceConfig PiperOrigin-RevId: 772111120 Source-Link: https://github.com/googleapis/googleapis/commit/09ba8d83538db44efe3b81eb46bbdaf4828771b7 Source-Link: 
https://github.com/googleapis/googleapis-gen/commit/2e15140cd1d085f5c2fc1d15ab1acc6d6ca1fca0 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiMmUxNTE0MGNkMWQwODVmNWMyZmMxZDE1YWIxYWNjNmQ2Y2ExZmNhMCJ9 -- 32f98ee9947279b5463661ecddf4b71b7699956d by Owl Bot : 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md -- 9b7365e7cc5520754ed0d49c9f732171ee609648 by Owl Bot : feat: add DnsPeeringConfig in service_networking.proto feat: add dns_peering_configs to PscInterfaceConfig PiperOrigin-RevId: 772146251 Source-Link: https://github.com/googleapis/googleapis/commit/7bdfcb7a08465f1292b9192990cc85382ee4e316 Source-Link: https://github.com/googleapis/googleapis-gen/commit/000e27a2b7392520f668ca2535db233d0be60b49 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiMDAwZTI3YTJiNzM5MjUyMGY2NjhjYTI1MzVkYjIzM2QwYmU2MGI0OSJ9 -- fab86e027de35ce8d7531d1f446de3ab4ef53120 by Owl Bot : 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md -- 47833a58e6b9f12864dda7e5b9293df1f98e0d60 by Owl Bot : feat: add RagEngineConfig update/get APIs to v1 feat: add Unprovisioned tier to RagEngineConfig to disable RagEngine service and delete all data within the service. PiperOrigin-RevId: 772174333 Source-Link: https://github.com/googleapis/googleapis/commit/d04f530e3d5d4b7e5b928d4198a8b78b8f4c352d Source-Link: https://github.com/googleapis/googleapis-gen/commit/893c5b9c68fe8aa80aae28242ff16e01396b22ad Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiODkzYzViOWM2OGZlOGFhODBhYWUyODI0MmZmMTZlMDEzOTZiMjJhZCJ9 -- bfaeefc36c0d0380b6c313ed7e1b3936bbe89011 by Owl Bot : 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md -- 4796c1912d74b2a7da0b11915cb7e20640cb2586 by Owl Bot : feat: add Scaled tier for RagEngineConfig to v1beta, equivalent to Enterprise feat: add Unprovisioned tier to RagEngineConfig in v1beta1 that can disable RagEngine service and delete all data within the service docs: Enterprise tier in RagEngineConfig, use Scaled tier instead. 
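For the new v1 ``RagEngineConfig`` get/update APIs in this patch, a minimal sketch of the synchronous call, following the generated ``get_rag_engine_config`` sample files added below; the project and location in the resource name are placeholders.

.. code-block:: Python

    # Hedged sketch based on the generated sample in this patch; the
    # RagEngineConfig resource is a singleton per project/location.
    from google.cloud import aiplatform_v1

    client = aiplatform_v1.VertexRagDataServiceClient()

    request = aiplatform_v1.GetRagEngineConfigRequest(
        name="projects/my-project/locations/us-central1/ragEngineConfig",
    )
    response = client.get_rag_engine_config(request=request)
    print(response)

``update_rag_engine_config`` follows the same pattern but returns a long-running operation, as shown in the generated samples further down.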
PiperOrigin-RevId: 772188314 Source-Link: https://github.com/googleapis/googleapis/commit/67a660f1bdc194cc043ffb91f918ee126117208f Source-Link: https://github.com/googleapis/googleapis-gen/commit/1db88b7e3709ecd23fd1e109ffc300c997b501fb Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiMWRiODhiN2UzNzA5ZWNkMjNmZDFlMTA5ZmZjMzAwYzk5N2I1MDFmYiJ9 -- e7a5ad0879a013fcae8c7efe7d41acfe5d6671e0 by Owl Bot : 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/5434 from googleapis:owl-bot-copy a824c3cb4d3d39d47a238846b92860dbb0c8fad0 PiperOrigin-RevId: 772620485 --- google/cloud/aiplatform_v1/__init__.py | 12 + .../cloud/aiplatform_v1/gapic_metadata.json | 30 + .../vertex_rag_data_service/async_client.py | 253 + .../vertex_rag_data_service/client.py | 261 + .../transports/base.py | 30 + .../transports/grpc.py | 57 + .../transports/grpc_asyncio.py | 68 + .../transports/rest.py | 446 ++ .../transports/rest_asyncio.py | 465 ++ .../transports/rest_base.py | 106 + google/cloud/aiplatform_v1/types/__init__.py | 12 + .../aiplatform_v1/types/service_networking.py | 49 + .../aiplatform_v1/types/vertex_rag_data.py | 102 + .../types/vertex_rag_data_service.py | 55 + google/cloud/aiplatform_v1beta1/__init__.py | 2 + .../aiplatform_v1beta1/types/__init__.py | 2 + .../types/service_networking.py | 49 + .../types/vertex_rag_data.py | 54 +- ...ata_service_get_rag_engine_config_async.py | 52 + ...data_service_get_rag_engine_config_sync.py | 52 + ..._service_update_rag_engine_config_async.py | 55 + ...a_service_update_rag_engine_config_sync.py | 55 + ...t_metadata_google.cloud.aiplatform.v1.json | 322 ++ .../gapic/aiplatform_v1/test_job_service.py | 62 +- .../test_persistent_resource_service.py | 44 +- .../aiplatform_v1/test_pipeline_service.py | 22 +- .../aiplatform_v1/test_schedule_service.py | 36 +- .../test_vertex_rag_data_service.py | 4639 ++++++++++++----- .../aiplatform_v1beta1/test_job_service.py | 62 +- .../test_persistent_resource_service.py | 44 +- .../test_pipeline_service.py | 22 +- .../test_schedule_service.py | 36 +- .../test_vertex_rag_data_service.py | 14 +- 33 files changed, 6203 insertions(+), 1367 deletions(-) create mode 100644 samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_sync.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_async.py create mode 100644 samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_sync.py diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index c5dd5ec66b..9527892e71 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -858,6 +858,7 @@ from .types.schedule_service import PauseScheduleRequest from .types.schedule_service import ResumeScheduleRequest from .types.schedule_service import UpdateScheduleRequest +from .types.service_networking import DnsPeeringConfig from .types.service_networking import PrivateServiceConnectConfig from .types.service_networking import PscAutomatedEndpoints from .types.service_networking import PSCAutomationConfig @@ -980,10 +981,12 @@ from .types.vertex_rag_data import RagChunk from .types.vertex_rag_data import 
RagCorpus from .types.vertex_rag_data import RagEmbeddingModelConfig +from .types.vertex_rag_data import RagEngineConfig from .types.vertex_rag_data import RagFile from .types.vertex_rag_data import RagFileChunkingConfig from .types.vertex_rag_data import RagFileParsingConfig from .types.vertex_rag_data import RagFileTransformationConfig +from .types.vertex_rag_data import RagManagedDbConfig from .types.vertex_rag_data import RagVectorDbConfig from .types.vertex_rag_data import UploadRagFileConfig from .types.vertex_rag_data import VertexAiSearchConfig @@ -992,6 +995,7 @@ from .types.vertex_rag_data_service import DeleteRagCorpusRequest from .types.vertex_rag_data_service import DeleteRagFileRequest from .types.vertex_rag_data_service import GetRagCorpusRequest +from .types.vertex_rag_data_service import GetRagEngineConfigRequest from .types.vertex_rag_data_service import GetRagFileRequest from .types.vertex_rag_data_service import ImportRagFilesOperationMetadata from .types.vertex_rag_data_service import ImportRagFilesRequest @@ -1002,6 +1006,8 @@ from .types.vertex_rag_data_service import ListRagFilesResponse from .types.vertex_rag_data_service import UpdateRagCorpusOperationMetadata from .types.vertex_rag_data_service import UpdateRagCorpusRequest +from .types.vertex_rag_data_service import UpdateRagEngineConfigOperationMetadata +from .types.vertex_rag_data_service import UpdateRagEngineConfigRequest from .types.vertex_rag_data_service import UploadRagFileRequest from .types.vertex_rag_data_service import UploadRagFileResponse from .types.vertex_rag_service import AugmentPromptRequest @@ -1310,6 +1316,7 @@ "DirectRawPredictResponse", "DirectUploadSource", "DiskSpec", + "DnsPeeringConfig", "DoubleArray", "DynamicRetrievalConfig", "EncryptionSpec", @@ -1446,6 +1453,7 @@ "GetPipelineJobRequest", "GetPublisherModelRequest", "GetRagCorpusRequest", + "GetRagEngineConfigRequest", "GetRagFileRequest", "GetReasoningEngineRequest", "GetScheduleRequest", @@ -1754,10 +1762,12 @@ "RagContexts", "RagCorpus", "RagEmbeddingModelConfig", + "RagEngineConfig", "RagFile", "RagFileChunkingConfig", "RagFileParsingConfig", "RagFileTransformationConfig", + "RagManagedDbConfig", "RagQuery", "RagRetrievalConfig", "RagVectorDbConfig", @@ -1981,6 +1991,8 @@ "UpdatePersistentResourceRequest", "UpdateRagCorpusOperationMetadata", "UpdateRagCorpusRequest", + "UpdateRagEngineConfigOperationMetadata", + "UpdateRagEngineConfigRequest", "UpdateReasoningEngineOperationMetadata", "UpdateReasoningEngineRequest", "UpdateScheduleRequest", diff --git a/google/cloud/aiplatform_v1/gapic_metadata.json b/google/cloud/aiplatform_v1/gapic_metadata.json index 170e56f293..7db7b7c9b2 100644 --- a/google/cloud/aiplatform_v1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1/gapic_metadata.json @@ -5066,6 +5066,11 @@ "get_rag_corpus" ] }, + "GetRagEngineConfig": { + "methods": [ + "get_rag_engine_config" + ] + }, "GetRagFile": { "methods": [ "get_rag_file" @@ -5091,6 +5096,11 @@ "update_rag_corpus" ] }, + "UpdateRagEngineConfig": { + "methods": [ + "update_rag_engine_config" + ] + }, "UploadRagFile": { "methods": [ "upload_rag_file" @@ -5121,6 +5131,11 @@ "get_rag_corpus" ] }, + "GetRagEngineConfig": { + "methods": [ + "get_rag_engine_config" + ] + }, "GetRagFile": { "methods": [ "get_rag_file" @@ -5146,6 +5161,11 @@ "update_rag_corpus" ] }, + "UpdateRagEngineConfig": { + "methods": [ + "update_rag_engine_config" + ] + }, "UploadRagFile": { "methods": [ "upload_rag_file" @@ -5176,6 +5196,11 @@ "get_rag_corpus" ] }, + 
"GetRagEngineConfig": { + "methods": [ + "get_rag_engine_config" + ] + }, "GetRagFile": { "methods": [ "get_rag_file" @@ -5201,6 +5226,11 @@ "update_rag_corpus" ] }, + "UpdateRagEngineConfig": { + "methods": [ + "update_rag_engine_config" + ] + }, "UploadRagFile": { "methods": [ "upload_rag_file" diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py index e0213cf348..73869171f7 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/async_client.py @@ -94,6 +94,12 @@ class VertexRagDataServiceAsyncClient: parse_rag_corpus_path = staticmethod( VertexRagDataServiceClient.parse_rag_corpus_path ) + rag_engine_config_path = staticmethod( + VertexRagDataServiceClient.rag_engine_config_path + ) + parse_rag_engine_config_path = staticmethod( + VertexRagDataServiceClient.parse_rag_engine_config_path + ) rag_file_path = staticmethod(VertexRagDataServiceClient.rag_file_path) parse_rag_file_path = staticmethod(VertexRagDataServiceClient.parse_rag_file_path) secret_version_path = staticmethod(VertexRagDataServiceClient.secret_version_path) @@ -1629,6 +1635,253 @@ async def sample_delete_rag_file(): # Done; return the response. return response + async def update_rag_engine_config( + self, + request: Optional[ + Union[vertex_rag_data_service.UpdateRagEngineConfigRequest, dict] + ] = None, + *, + rag_engine_config: Optional[vertex_rag_data.RagEngineConfig] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates a RagEngineConfig. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_update_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.UpdateRagEngineConfigRequest( + ) + + # Make the request + operation = client.update_rag_engine_config(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.UpdateRagEngineConfigRequest, dict]]): + The request object. Request message for + [VertexRagDataService.UpdateRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig]. + rag_engine_config (:class:`google.cloud.aiplatform_v1.types.RagEngineConfig`): + Required. The updated + RagEngineConfig. + NOTE: Downgrading your RagManagedDb's + ComputeTier could temporarily increase + request latencies until the operation is + fully complete. + + This corresponds to the ``rag_engine_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.RagEngineConfig` + Config for RagEngine. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + flattened_params = [rag_engine_config] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, vertex_rag_data_service.UpdateRagEngineConfigRequest + ): + request = vertex_rag_data_service.UpdateRagEngineConfigRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if rag_engine_config is not None: + request.rag_engine_config = rag_engine_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_rag_engine_config + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("rag_engine_config.name", request.rag_engine_config.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + vertex_rag_data.RagEngineConfig, + metadata_type=vertex_rag_data_service.UpdateRagEngineConfigOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_rag_engine_config( + self, + request: Optional[ + Union[vertex_rag_data_service.GetRagEngineConfigRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> vertex_rag_data.RagEngineConfig: + r"""Gets a RagEngineConfig. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + async def sample_get_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetRagEngineConfigRequest( + name="name_value", + ) + + # Make the request + response = await client.get_rag_engine_config(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.cloud.aiplatform_v1.types.GetRagEngineConfigRequest, dict]]): + The request object. Request message for + [VertexRagDataService.GetRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.GetRagEngineConfig] + name (:class:`str`): + Required. The name of the RagEngineConfig resource. + Format: + ``projects/{project}/locations/{location}/ragEngineConfig`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.cloud.aiplatform_v1.types.RagEngineConfig: + Config for RagEngine. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + flattened_params = [name] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.GetRagEngineConfigRequest): + request = vertex_rag_data_service.GetRagEngineConfigRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_rag_engine_config + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + async def list_operations( self, request: Optional[operations_pb2.ListOperationsRequest] = None, diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py index ee407688da..c12844014f 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/client.py @@ -288,6 +288,26 @@ def parse_rag_corpus_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def rag_engine_config_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified rag_engine_config string.""" + return "projects/{project}/locations/{location}/ragEngineConfig".format( + project=project, + location=location, + ) + + @staticmethod + def parse_rag_engine_config_path(path: str) -> Dict[str, str]: + """Parses a rag_engine_config path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/ragEngineConfig$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def rag_file_path( project: str, @@ -2139,6 +2159,247 @@ def sample_delete_rag_file(): # Done; return the response. return response + def update_rag_engine_config( + self, + request: Optional[ + Union[vertex_rag_data_service.UpdateRagEngineConfigRequest, dict] + ] = None, + *, + rag_engine_config: Optional[vertex_rag_data.RagEngineConfig] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gac_operation.Operation: + r"""Updates a RagEngineConfig. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_update_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.UpdateRagEngineConfigRequest( + ) + + # Make the request + operation = client.update_rag_engine_config(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.UpdateRagEngineConfigRequest, dict]): + The request object. Request message for + [VertexRagDataService.UpdateRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig]. + rag_engine_config (google.cloud.aiplatform_v1.types.RagEngineConfig): + Required. The updated + RagEngineConfig. + NOTE: Downgrading your RagManagedDb's + ComputeTier could temporarily increase + request latencies until the operation is + fully complete. + + This corresponds to the ``rag_engine_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. 
Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.RagEngineConfig` + Config for RagEngine. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + flattened_params = [rag_engine_config] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance( + request, vertex_rag_data_service.UpdateRagEngineConfigRequest + ): + request = vertex_rag_data_service.UpdateRagEngineConfigRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if rag_engine_config is not None: + request.rag_engine_config = rag_engine_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_rag_engine_config] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("rag_engine_config.name", request.rag_engine_config.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = gac_operation.from_gapic( + response, + self._transport.operations_client, + vertex_rag_data.RagEngineConfig, + metadata_type=vertex_rag_data_service.UpdateRagEngineConfigOperationMetadata, + ) + + # Done; return the response. + return response + + def get_rag_engine_config( + self, + request: Optional[ + Union[vertex_rag_data_service.GetRagEngineConfigRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> vertex_rag_data.RagEngineConfig: + r"""Gets a RagEngineConfig. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1 + + def sample_get_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetRagEngineConfigRequest( + name="name_value", + ) + + # Make the request + response = client.get_rag_engine_config(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1.types.GetRagEngineConfigRequest, dict]): + The request object. Request message for + [VertexRagDataService.GetRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.GetRagEngineConfig] + name (str): + Required. The name of the RagEngineConfig resource. + Format: + ``projects/{project}/locations/{location}/ragEngineConfig`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.cloud.aiplatform_v1.types.RagEngineConfig: + Config for RagEngine. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + flattened_params = [name] + has_flattened_params = ( + len([param for param in flattened_params if param is not None]) > 0 + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, vertex_rag_data_service.GetRagEngineConfigRequest): + request = vertex_rag_data_service.GetRagEngineConfigRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_rag_engine_config] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + def __enter__(self) -> "VertexRagDataServiceClient": return self diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/base.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/base.py index cac9fda502..4739d84b7b 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/base.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/base.py @@ -189,6 +189,16 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.update_rag_engine_config: gapic_v1.method.wrap_method( + self.update_rag_engine_config, + default_timeout=None, + client_info=client_info, + ), + self.get_rag_engine_config: gapic_v1.method.wrap_method( + self.get_rag_engine_config, + default_timeout=None, + client_info=client_info, + ), self.get_location: gapic_v1.method.wrap_method( self.get_location, default_timeout=None, @@ -354,6 +364,26 @@ def delete_rag_file( ]: raise NotImplementedError() + @property + def update_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagEngineConfigRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def get_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.GetRagEngineConfigRequest], + Union[ + vertex_rag_data.RagEngineConfig, Awaitable[vertex_rag_data.RagEngineConfig] + ], + ]: + raise NotImplementedError() + @property def list_operations( self, diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc.py index adf1675f57..65f10ed8af 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc.py @@ -625,6 +625,63 @@ def delete_rag_file( ) return self._stubs["delete_rag_file"] + @property + def update_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagEngineConfigRequest], operations_pb2.Operation + ]: + r"""Return a callable for the update rag engine config method over gRPC. + + Updates a RagEngineConfig. + + Returns: + Callable[[~.UpdateRagEngineConfigRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_rag_engine_config" not in self._stubs: + self._stubs["update_rag_engine_config"] = self._logged_channel.unary_unary( + "/google.cloud.aiplatform.v1.VertexRagDataService/UpdateRagEngineConfig", + request_serializer=vertex_rag_data_service.UpdateRagEngineConfigRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["update_rag_engine_config"] + + @property + def get_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.GetRagEngineConfigRequest], + vertex_rag_data.RagEngineConfig, + ]: + r"""Return a callable for the get rag engine config method over gRPC. + + Gets a RagEngineConfig. + + Returns: + Callable[[~.GetRagEngineConfigRequest], + ~.RagEngineConfig]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_rag_engine_config" not in self._stubs: + self._stubs["get_rag_engine_config"] = self._logged_channel.unary_unary( + "/google.cloud.aiplatform.v1.VertexRagDataService/GetRagEngineConfig", + request_serializer=vertex_rag_data_service.GetRagEngineConfigRequest.serialize, + response_deserializer=vertex_rag_data.RagEngineConfig.deserialize, + ) + return self._stubs["get_rag_engine_config"] + def close(self): self._logged_channel.close() diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc_asyncio.py index 3c7d95d80e..e47ec6e5f6 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/grpc_asyncio.py @@ -641,6 +641,64 @@ def delete_rag_file( ) return self._stubs["delete_rag_file"] + @property + def update_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagEngineConfigRequest], + Awaitable[operations_pb2.Operation], + ]: + r"""Return a callable for the update rag engine config method over gRPC. + + Updates a RagEngineConfig. + + Returns: + Callable[[~.UpdateRagEngineConfigRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_rag_engine_config" not in self._stubs: + self._stubs["update_rag_engine_config"] = self._logged_channel.unary_unary( + "/google.cloud.aiplatform.v1.VertexRagDataService/UpdateRagEngineConfig", + request_serializer=vertex_rag_data_service.UpdateRagEngineConfigRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["update_rag_engine_config"] + + @property + def get_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.GetRagEngineConfigRequest], + Awaitable[vertex_rag_data.RagEngineConfig], + ]: + r"""Return a callable for the get rag engine config method over gRPC. + + Gets a RagEngineConfig. + + Returns: + Callable[[~.GetRagEngineConfigRequest], + Awaitable[~.RagEngineConfig]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_rag_engine_config" not in self._stubs: + self._stubs["get_rag_engine_config"] = self._logged_channel.unary_unary( + "/google.cloud.aiplatform.v1.VertexRagDataService/GetRagEngineConfig", + request_serializer=vertex_rag_data_service.GetRagEngineConfigRequest.serialize, + response_deserializer=vertex_rag_data.RagEngineConfig.deserialize, + ) + return self._stubs["get_rag_engine_config"] + def _prep_wrapped_messages(self, client_info): """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" self._wrapped_methods = { @@ -694,6 +752,16 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.update_rag_engine_config: self._wrap_method( + self.update_rag_engine_config, + default_timeout=None, + client_info=client_info, + ), + self.get_rag_engine_config: self._wrap_method( + self.get_rag_engine_config, + default_timeout=None, + client_info=client_info, + ), self.get_location: self._wrap_method( self.get_location, default_timeout=None, diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest.py index 0475f53e48..c22350780c 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest.py @@ -116,6 +116,14 @@ def post_get_rag_corpus(self, response): logging.log(f"Received response: {response}") return response + def pre_get_rag_engine_config(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_rag_engine_config(self, response): + logging.log(f"Received response: {response}") + return response + def pre_get_rag_file(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -156,6 +164,14 @@ def post_update_rag_corpus(self, response): logging.log(f"Received response: {response}") return response + def pre_update_rag_engine_config(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_rag_engine_config(self, response): + logging.log(f"Received response: {response}") + return response + def pre_upload_rag_file(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -366,6 +382,57 @@ def post_get_rag_corpus_with_metadata( """ return response, metadata + def pre_get_rag_engine_config( + self, + request: vertex_rag_data_service.GetRagEngineConfigRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + vertex_rag_data_service.GetRagEngineConfigRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for get_rag_engine_config + + Override in a subclass to manipulate the request or metadata + before they are sent to the VertexRagDataService server. + """ + return request, metadata + + def post_get_rag_engine_config( + self, response: vertex_rag_data.RagEngineConfig + ) -> vertex_rag_data.RagEngineConfig: + """Post-rpc interceptor for get_rag_engine_config + + DEPRECATED. Please use the `post_get_rag_engine_config_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the VertexRagDataService server but before + it is returned to user code. This `post_get_rag_engine_config` interceptor runs + before the `post_get_rag_engine_config_with_metadata` interceptor. 
+ """ + return response + + def post_get_rag_engine_config_with_metadata( + self, + response: vertex_rag_data.RagEngineConfig, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + vertex_rag_data.RagEngineConfig, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for get_rag_engine_config + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the VertexRagDataService server but before it is returned to user code. + + We recommend only using this `post_get_rag_engine_config_with_metadata` + interceptor in new development instead of the `post_get_rag_engine_config` interceptor. + When both interceptors are used, this `post_get_rag_engine_config_with_metadata` interceptor runs after the + `post_get_rag_engine_config` interceptor. The (possibly modified) response returned by + `post_get_rag_engine_config` will be passed to + `post_get_rag_engine_config_with_metadata`. + """ + return response, metadata + def pre_get_rag_file( self, request: vertex_rag_data_service.GetRagFileRequest, @@ -617,6 +684,55 @@ def post_update_rag_corpus_with_metadata( """ return response, metadata + def pre_update_rag_engine_config( + self, + request: vertex_rag_data_service.UpdateRagEngineConfigRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + vertex_rag_data_service.UpdateRagEngineConfigRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for update_rag_engine_config + + Override in a subclass to manipulate the request or metadata + before they are sent to the VertexRagDataService server. + """ + return request, metadata + + def post_update_rag_engine_config( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for update_rag_engine_config + + DEPRECATED. Please use the `post_update_rag_engine_config_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the VertexRagDataService server but before + it is returned to user code. This `post_update_rag_engine_config` interceptor runs + before the `post_update_rag_engine_config_with_metadata` interceptor. + """ + return response + + def post_update_rag_engine_config_with_metadata( + self, + response: operations_pb2.Operation, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[operations_pb2.Operation, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for update_rag_engine_config + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the VertexRagDataService server but before it is returned to user code. + + We recommend only using this `post_update_rag_engine_config_with_metadata` + interceptor in new development instead of the `post_update_rag_engine_config` interceptor. + When both interceptors are used, this `post_update_rag_engine_config_with_metadata` interceptor runs after the + `post_update_rag_engine_config` interceptor. The (possibly modified) response returned by + `post_update_rag_engine_config` will be passed to + `post_update_rag_engine_config_with_metadata`. 
+ """ + return response, metadata + def pre_upload_rag_file( self, request: vertex_rag_data_service.UploadRagFileRequest, @@ -3558,6 +3674,157 @@ def __call__( ) return resp + class _GetRagEngineConfig( + _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig, + VertexRagDataServiceRestStub, + ): + def __hash__(self): + return hash("VertexRagDataServiceRestTransport.GetRagEngineConfig") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: vertex_rag_data_service.GetRagEngineConfigRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> vertex_rag_data.RagEngineConfig: + r"""Call the get rag engine config method over HTTP. + + Args: + request (~.vertex_rag_data_service.GetRagEngineConfigRequest): + The request object. Request message for + [VertexRagDataService.GetRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.GetRagEngineConfig] + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.vertex_rag_data.RagEngineConfig: + Config for RagEngine. + """ + + http_options = ( + _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_rag_engine_config( + request, metadata + ) + transcoded_request = _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.cloud.aiplatform_v1.VertexRagDataServiceClient.GetRagEngineConfig", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "GetRagEngineConfig", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ( + VertexRagDataServiceRestTransport._GetRagEngineConfig._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = vertex_rag_data.RagEngineConfig() + pb_resp = vertex_rag_data.RagEngineConfig.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_rag_engine_config(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_get_rag_engine_config_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = vertex_rag_data.RagEngineConfig.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.cloud.aiplatform_v1.VertexRagDataServiceClient.get_rag_engine_config", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "GetRagEngineConfig", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + class _GetRagFile( _BaseVertexRagDataServiceRestTransport._BaseGetRagFile, VertexRagDataServiceRestStub, @@ -4323,6 +4590,164 @@ def __call__( ) return resp + class _UpdateRagEngineConfig( + _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig, + VertexRagDataServiceRestStub, + ): + def __hash__(self): + return hash("VertexRagDataServiceRestTransport.UpdateRagEngineConfig") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: vertex_rag_data_service.UpdateRagEngineConfigRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the update rag engine config method over HTTP. + + Args: + request (~.vertex_rag_data_service.UpdateRagEngineConfigRequest): + The request object. Request message for + [VertexRagDataService.UpdateRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. 
+ + """ + + http_options = ( + _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_http_options() + ) + + request, metadata = self._interceptor.pre_update_rag_engine_config( + request, metadata + ) + transcoded_request = _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_transcoded_request( + http_options, request + ) + + body = _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.cloud.aiplatform_v1.VertexRagDataServiceClient.UpdateRagEngineConfig", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "UpdateRagEngineConfig", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ( + VertexRagDataServiceRestTransport._UpdateRagEngineConfig._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_update_rag_engine_config(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = self._interceptor.post_update_rag_engine_config_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.cloud.aiplatform_v1.VertexRagDataServiceClient.update_rag_engine_config", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "UpdateRagEngineConfig", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + class _UploadRagFile( _BaseVertexRagDataServiceRestTransport._BaseUploadRagFile, VertexRagDataServiceRestStub, @@ -4520,6 +4945,17 @@ def get_rag_corpus( # In C++ this would require a dynamic_cast return self._GetRagCorpus(self._session, self._host, self._interceptor) # type: ignore + @property + def get_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.GetRagEngineConfigRequest], + vertex_rag_data.RagEngineConfig, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._GetRagEngineConfig(self._session, self._host, self._interceptor) # type: ignore + @property def get_rag_file( self, @@ -4570,6 +5006,16 @@ def update_rag_corpus( # In C++ this would require a dynamic_cast return self._UpdateRagCorpus(self._session, self._host, self._interceptor) # type: ignore + @property + def update_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagEngineConfigRequest], operations_pb2.Operation + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateRagEngineConfig(self._session, self._host, self._interceptor) # type: ignore + @property def upload_rag_file( self, diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_asyncio.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_asyncio.py index 92f95a0e6a..db18cfef3d 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_asyncio.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_asyncio.py @@ -133,6 +133,14 @@ async def post_get_rag_corpus(self, response): logging.log(f"Received response: {response}") return response + async def pre_get_rag_engine_config(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + async def post_get_rag_engine_config(self, response): + logging.log(f"Received response: {response}") + return response + async def pre_get_rag_file(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -173,6 +181,14 @@ async def post_update_rag_corpus(self, response): logging.log(f"Received response: {response}") return response + async def pre_update_rag_engine_config(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + async def post_update_rag_engine_config(self, response): + logging.log(f"Received response: {response}") + return response + async def pre_upload_rag_file(self, request, metadata): logging.log(f"Received request: {request}") return request, metadata @@ -383,6 +399,57 @@ async def post_get_rag_corpus_with_metadata( """ return response, metadata + async def pre_get_rag_engine_config( + self, + request: vertex_rag_data_service.GetRagEngineConfigRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + vertex_rag_data_service.GetRagEngineConfigRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for get_rag_engine_config + + Override in a subclass to manipulate the request or metadata + before they are sent to the VertexRagDataService server. + """ + return request, metadata + + async def post_get_rag_engine_config( + self, response: vertex_rag_data.RagEngineConfig + ) -> vertex_rag_data.RagEngineConfig: + """Post-rpc interceptor for get_rag_engine_config + + DEPRECATED. Please use the `post_get_rag_engine_config_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the VertexRagDataService server but before + it is returned to user code. This `post_get_rag_engine_config` interceptor runs + before the `post_get_rag_engine_config_with_metadata` interceptor. 
+ """ + return response + + async def post_get_rag_engine_config_with_metadata( + self, + response: vertex_rag_data.RagEngineConfig, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + vertex_rag_data.RagEngineConfig, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Post-rpc interceptor for get_rag_engine_config + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the VertexRagDataService server but before it is returned to user code. + + We recommend only using this `post_get_rag_engine_config_with_metadata` + interceptor in new development instead of the `post_get_rag_engine_config` interceptor. + When both interceptors are used, this `post_get_rag_engine_config_with_metadata` interceptor runs after the + `post_get_rag_engine_config` interceptor. The (possibly modified) response returned by + `post_get_rag_engine_config` will be passed to + `post_get_rag_engine_config_with_metadata`. + """ + return response, metadata + async def pre_get_rag_file( self, request: vertex_rag_data_service.GetRagFileRequest, @@ -634,6 +701,55 @@ async def post_update_rag_corpus_with_metadata( """ return response, metadata + async def pre_update_rag_engine_config( + self, + request: vertex_rag_data_service.UpdateRagEngineConfigRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + vertex_rag_data_service.UpdateRagEngineConfigRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for update_rag_engine_config + + Override in a subclass to manipulate the request or metadata + before they are sent to the VertexRagDataService server. + """ + return request, metadata + + async def post_update_rag_engine_config( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for update_rag_engine_config + + DEPRECATED. Please use the `post_update_rag_engine_config_with_metadata` + interceptor instead. + + Override in a subclass to read or manipulate the response + after it is returned by the VertexRagDataService server but before + it is returned to user code. This `post_update_rag_engine_config` interceptor runs + before the `post_update_rag_engine_config_with_metadata` interceptor. + """ + return response + + async def post_update_rag_engine_config_with_metadata( + self, + response: operations_pb2.Operation, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[operations_pb2.Operation, Sequence[Tuple[str, Union[str, bytes]]]]: + """Post-rpc interceptor for update_rag_engine_config + + Override in a subclass to read or manipulate the response or metadata after it + is returned by the VertexRagDataService server but before it is returned to user code. + + We recommend only using this `post_update_rag_engine_config_with_metadata` + interceptor in new development instead of the `post_update_rag_engine_config` interceptor. + When both interceptors are used, this `post_update_rag_engine_config_with_metadata` interceptor runs after the + `post_update_rag_engine_config` interceptor. The (possibly modified) response returned by + `post_update_rag_engine_config` will be passed to + `post_update_rag_engine_config_with_metadata`. 
+ """ + return response, metadata + async def pre_upload_rag_file( self, request: vertex_rag_data_service.UploadRagFileRequest, @@ -1051,6 +1167,16 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.update_rag_engine_config: self._wrap_method( + self.update_rag_engine_config, + default_timeout=None, + client_info=client_info, + ), + self.get_rag_engine_config: self._wrap_method( + self.get_rag_engine_config, + default_timeout=None, + client_info=client_info, + ), self.get_location: self._wrap_method( self.get_location, default_timeout=None, @@ -1746,6 +1872,161 @@ async def __call__( return resp + class _GetRagEngineConfig( + _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig, + AsyncVertexRagDataServiceRestStub, + ): + def __hash__(self): + return hash("AsyncVertexRagDataServiceRestTransport.GetRagEngineConfig") + + @staticmethod + async def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = await getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + async def __call__( + self, + request: vertex_rag_data_service.GetRagEngineConfigRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> vertex_rag_data.RagEngineConfig: + r"""Call the get rag engine config method over HTTP. + + Args: + request (~.vertex_rag_data_service.GetRagEngineConfigRequest): + The request object. Request message for + [VertexRagDataService.GetRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.GetRagEngineConfig] + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.vertex_rag_data.RagEngineConfig: + Config for RagEngine. 
+ """ + + http_options = ( + _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig._get_http_options() + ) + + request, metadata = await self._interceptor.pre_get_rag_engine_config( + request, metadata + ) + transcoded_request = _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.cloud.aiplatform_v1.VertexRagDataServiceClient.GetRagEngineConfig", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "GetRagEngineConfig", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = await AsyncVertexRagDataServiceRestTransport._GetRagEngineConfig._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + content = await response.read() + payload = json.loads(content.decode("utf-8")) + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + raise core_exceptions.format_http_response_error(response, method, request_url, payload) # type: ignore + + # Return the response + resp = vertex_rag_data.RagEngineConfig() + pb_resp = vertex_rag_data.RagEngineConfig.pb(resp) + content = await response.read() + json_format.Parse(content, pb_resp, ignore_unknown_fields=True) + resp = await self._interceptor.post_get_rag_engine_config(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + resp, _ = await self._interceptor.post_get_rag_engine_config_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = vertex_rag_data.RagEngineConfig.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": "OK", # need to obtain this properly + } + _LOGGER.debug( + "Received response for google.cloud.aiplatform_v1.VertexRagDataServiceAsyncClient.get_rag_engine_config", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "GetRagEngineConfig", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + + return resp + class _GetRagFile( _BaseVertexRagDataServiceRestTransport._BaseGetRagFile, AsyncVertexRagDataServiceRestStub, @@ -2551,6 +2832,173 @@ async def __call__( return resp + class _UpdateRagEngineConfig( + _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig, + AsyncVertexRagDataServiceRestStub, + ): + def __hash__(self): + return hash("AsyncVertexRagDataServiceRestTransport.UpdateRagEngineConfig") + + @staticmethod + async def 
_get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = await getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + async def __call__( + self, + request: vertex_rag_data_service.UpdateRagEngineConfigRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the update rag engine config method over HTTP. + + Args: + request (~.vertex_rag_data_service.UpdateRagEngineConfigRequest): + The request object. Request message for + [VertexRagDataService.UpdateRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig]. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. + + """ + + http_options = ( + _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_http_options() + ) + + request, metadata = await self._interceptor.pre_update_rag_engine_config( + request, metadata + ) + transcoded_request = _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_transcoded_request( + http_options, request + ) + + body = _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.cloud.aiplatform_v1.VertexRagDataServiceClient.UpdateRagEngineConfig", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "UpdateRagEngineConfig", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = await AsyncVertexRagDataServiceRestTransport._UpdateRagEngineConfig._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
+ if response.status_code >= 400: + content = await response.read() + payload = json.loads(content.decode("utf-8")) + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + raise core_exceptions.format_http_response_error(response, method, request_url, payload) # type: ignore + + # Return the response + resp = operations_pb2.Operation() + pb_resp = resp + content = await response.read() + json_format.Parse(content, pb_resp, ignore_unknown_fields=True) + resp = await self._interceptor.post_update_rag_engine_config(resp) + response_metadata = [(k, str(v)) for k, v in response.headers.items()] + ( + resp, + _, + ) = await self._interceptor.post_update_rag_engine_config_with_metadata( + resp, response_metadata + ) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": "OK", # need to obtain this properly + } + _LOGGER.debug( + "Received response for google.cloud.aiplatform_v1.VertexRagDataServiceAsyncClient.update_rag_engine_config", + extra={ + "serviceName": "google.cloud.aiplatform.v1.VertexRagDataService", + "rpcName": "UpdateRagEngineConfig", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + + return resp + class _UploadRagFile( _BaseVertexRagDataServiceRestTransport._BaseUploadRagFile, AsyncVertexRagDataServiceRestStub, @@ -4701,6 +5149,15 @@ def get_rag_corpus( ]: return self._GetRagCorpus(self._session, self._host, self._interceptor) # type: ignore + @property + def get_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.GetRagEngineConfigRequest], + vertex_rag_data.RagEngineConfig, + ]: + return self._GetRagEngineConfig(self._session, self._host, self._interceptor) # type: ignore + @property def get_rag_file( self, @@ -4741,6 +5198,14 @@ def update_rag_corpus( ]: return self._UpdateRagCorpus(self._session, self._host, self._interceptor) # type: ignore + @property + def update_rag_engine_config( + self, + ) -> Callable[ + [vertex_rag_data_service.UpdateRagEngineConfigRequest], operations_pb2.Operation + ]: + return self._UpdateRagEngineConfig(self._session, self._host, self._interceptor) # type: ignore + @property def upload_rag_file( self, diff --git a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_base.py b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_base.py index e379929493..fd4a419037 100644 --- a/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_base.py +++ b/google/cloud/aiplatform_v1/services/vertex_rag_data_service/transports/rest_base.py @@ -292,6 +292,53 @@ def _get_query_params_json(transcoded_request): query_params["$alt"] = "json;enum-encoding=int" return query_params + class _BaseGetRagEngineConfig: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1/{name=projects/*/locations/*/ragEngineConfig}", + }, + ] + return 
http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = vertex_rag_data_service.GetRagEngineConfigRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseVertexRagDataServiceRestTransport._BaseGetRagEngineConfig._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseGetRagFile: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") @@ -547,6 +594,65 @@ def _get_query_params_json(transcoded_request): query_params["$alt"] = "json;enum-encoding=int" return query_params + class _BaseUpdateRagEngineConfig: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1/{rag_engine_config.name=projects/*/locations/*/ragEngineConfig}", + "body": "rag_engine_config", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = vertex_rag_data_service.UpdateRagEngineConfigRequest.pb( + request + ) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseVertexRagDataServiceRestTransport._BaseUpdateRagEngineConfig._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseUploadRagFile: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index be706b6d02..20ea774240 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -953,6 +953,7 @@ UpdateScheduleRequest, ) from .service_networking import ( + DnsPeeringConfig, PrivateServiceConnectConfig, PscAutomatedEndpoints, PSCAutomationConfig, @@ -1109,10 +1110,12 @@ RagChunk, RagCorpus, RagEmbeddingModelConfig, + RagEngineConfig, RagFile, RagFileChunkingConfig, RagFileParsingConfig, RagFileTransformationConfig, + RagManagedDbConfig, RagVectorDbConfig, UploadRagFileConfig, VertexAiSearchConfig, @@ -1123,6 +1126,7 @@ DeleteRagCorpusRequest, DeleteRagFileRequest, GetRagCorpusRequest, + GetRagEngineConfigRequest, GetRagFileRequest, ImportRagFilesOperationMetadata, ImportRagFilesRequest, @@ -1133,6 +1137,8 @@ ListRagFilesResponse, UpdateRagCorpusOperationMetadata, UpdateRagCorpusRequest, + 
UpdateRagEngineConfigOperationMetadata, + UpdateRagEngineConfigRequest, UploadRagFileRequest, UploadRagFileResponse, ) @@ -1917,6 +1923,7 @@ "PauseScheduleRequest", "ResumeScheduleRequest", "UpdateScheduleRequest", + "DnsPeeringConfig", "PrivateServiceConnectConfig", "PscAutomatedEndpoints", "PSCAutomationConfig", @@ -2039,10 +2046,12 @@ "RagChunk", "RagCorpus", "RagEmbeddingModelConfig", + "RagEngineConfig", "RagFile", "RagFileChunkingConfig", "RagFileParsingConfig", "RagFileTransformationConfig", + "RagManagedDbConfig", "RagVectorDbConfig", "UploadRagFileConfig", "VertexAiSearchConfig", @@ -2051,6 +2060,7 @@ "DeleteRagCorpusRequest", "DeleteRagFileRequest", "GetRagCorpusRequest", + "GetRagEngineConfigRequest", "GetRagFileRequest", "ImportRagFilesOperationMetadata", "ImportRagFilesRequest", @@ -2061,6 +2071,8 @@ "ListRagFilesResponse", "UpdateRagCorpusOperationMetadata", "UpdateRagCorpusRequest", + "UpdateRagEngineConfigOperationMetadata", + "UpdateRagEngineConfigRequest", "UploadRagFileRequest", "UploadRagFileResponse", "AugmentPromptRequest", diff --git a/google/cloud/aiplatform_v1/types/service_networking.py b/google/cloud/aiplatform_v1/types/service_networking.py index 42fdd90d59..b5bb34a41f 100644 --- a/google/cloud/aiplatform_v1/types/service_networking.py +++ b/google/cloud/aiplatform_v1/types/service_networking.py @@ -27,6 +27,7 @@ "PrivateServiceConnectConfig", "PscAutomatedEndpoints", "PscInterfaceConfig", + "DnsPeeringConfig", }, ) @@ -130,12 +131,60 @@ class PscInterfaceConfig(proto.Message): [created a network attachment] (https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments). This field is only used for resources using PSC-I. + dns_peering_configs (MutableSequence[google.cloud.aiplatform_v1.types.DnsPeeringConfig]): + Optional. DNS peering configurations. When + specified, Vertex AI will attempt to configure + DNS peering zones in the tenant project VPC to + resolve the specified domains using the target + network's Cloud DNS. The user must grant the + dns.peer role to the Vertex AI Service Agent on + the target project. """ network_attachment: str = proto.Field( proto.STRING, number=1, ) + dns_peering_configs: MutableSequence["DnsPeeringConfig"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="DnsPeeringConfig", + ) + + +class DnsPeeringConfig(proto.Message): + r"""DNS peering configuration. These configurations are used to + create DNS peering zones in the Vertex tenant project VPC, + enabling resolution of records within the specified domain + hosted in the target network's Cloud DNS. + + Attributes: + domain (str): + Required. The DNS name suffix of the zone + being peered to, e.g., + "my-internal-domain.corp.". Must end with a dot. + target_project (str): + Required. The project ID hosting the Cloud + DNS managed zone that contains the 'domain'. The + Vertex AI Service Agent requires the dns.peer + role on this project. + target_network (str): + Required. The VPC network name in the target_project where + the DNS zone specified by 'domain' is visible. 
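A minimal sketch of how the new DNS peering fields are meant to compose, assuming the type re-exports this patch adds to google.cloud.aiplatform_v1.types; every resource name below is a placeholder:

    from google.cloud.aiplatform_v1 import types as aiplatform_types

    # Peer the Vertex tenant VPC to a Cloud DNS zone hosted in another project.
    dns_peering = aiplatform_types.DnsPeeringConfig(
        domain="my-internal-domain.corp.",  # must end with a trailing dot
        target_project="my-dns-project",    # project hosting the Cloud DNS managed zone
        target_network="my-vpc-network",    # VPC where that zone is visible
    )

    psc_interface_config = aiplatform_types.PscInterfaceConfig(
        network_attachment="projects/my-project/regions/us-central1/networkAttachments/my-attachment",
        dns_peering_configs=[dns_peering],
    )

As the field documentation notes, the dns.peer role must be granted to the Vertex AI Service Agent on the target project before the peering can resolve any records.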
+ """ + + domain: str = proto.Field( + proto.STRING, + number=1, + ) + target_project: str = proto.Field( + proto.STRING, + number=2, + ) + target_network: str = proto.Field( + proto.STRING, + number=3, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1/types/vertex_rag_data.py index fcab74428f..91b3e505cf 100644 --- a/google/cloud/aiplatform_v1/types/vertex_rag_data.py +++ b/google/cloud/aiplatform_v1/types/vertex_rag_data.py @@ -41,6 +41,8 @@ "RagFileParsingConfig", "UploadRagFileConfig", "ImportRagFilesConfig", + "RagManagedDbConfig", + "RagEngineConfig", }, ) @@ -988,4 +990,104 @@ class ImportRagFilesConfig(proto.Message): ) +class RagManagedDbConfig(proto.Message): + r"""Configuration message for RagManagedDb used by RagEngine. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + scaled (google.cloud.aiplatform_v1.types.RagManagedDbConfig.Scaled): + Sets the RagManagedDb to the Scaled tier. + + This field is a member of `oneof`_ ``tier``. + basic (google.cloud.aiplatform_v1.types.RagManagedDbConfig.Basic): + Sets the RagManagedDb to the Basic tier. + + This field is a member of `oneof`_ ``tier``. + unprovisioned (google.cloud.aiplatform_v1.types.RagManagedDbConfig.Unprovisioned): + Sets the RagManagedDb to the Unprovisioned + tier. + + This field is a member of `oneof`_ ``tier``. + """ + + class Scaled(proto.Message): + r"""Scaled tier offers production grade performance along with + autoscaling functionality. It is suitable for customers with + large amounts of data or performance sensitive workloads. + + """ + + class Basic(proto.Message): + r"""Basic tier is a cost-effective and low compute tier suitable for the + following cases: + + - Experimenting with RagManagedDb. + - Small data size. + - Latency insensitive workload. + - Only using RAG Engine with external vector DBs. + + NOTE: This is the default tier if not explicitly chosen. + + """ + + class Unprovisioned(proto.Message): + r"""Disables the RAG Engine service and deletes all your data + held within this service. This will halt the billing of the + service. + + NOTE: Once deleted the data cannot be recovered. To start using + RAG Engine again, you will need to update the tier by calling + the UpdateRagEngineConfig API. + + """ + + scaled: Scaled = proto.Field( + proto.MESSAGE, + number=4, + oneof="tier", + message=Scaled, + ) + basic: Basic = proto.Field( + proto.MESSAGE, + number=2, + oneof="tier", + message=Basic, + ) + unprovisioned: Unprovisioned = proto.Field( + proto.MESSAGE, + number=3, + oneof="tier", + message=Unprovisioned, + ) + + +class RagEngineConfig(proto.Message): + r"""Config for RagEngine. + + Attributes: + name (str): + Identifier. The name of the RagEngineConfig. Format: + ``projects/{project}/locations/{location}/ragEngineConfig`` + rag_managed_db_config (google.cloud.aiplatform_v1.types.RagManagedDbConfig): + The config of the RagManagedDb used by + RagEngine. 
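For orientation, a RagEngineConfig pinned to the Basic tier might be assembled from the types above; the project and location are placeholders, while the name format is fixed by the resource:

    from google.cloud.aiplatform_v1 import types as aiplatform_types

    config = aiplatform_types.RagEngineConfig(
        name="projects/my-project/locations/us-central1/ragEngineConfig",
        rag_managed_db_config=aiplatform_types.RagManagedDbConfig(
            # Exactly one member of the `tier` oneof may be set:
            # scaled, basic, or unprovisioned.
            basic=aiplatform_types.RagManagedDbConfig.Basic(),
        ),
    )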
+ """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + rag_managed_db_config: "RagManagedDbConfig" = proto.Field( + proto.MESSAGE, + number=2, + message="RagManagedDbConfig", + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/vertex_rag_data_service.py b/google/cloud/aiplatform_v1/types/vertex_rag_data_service.py index b3458b855f..bb1f3336f0 100644 --- a/google/cloud/aiplatform_v1/types/vertex_rag_data_service.py +++ b/google/cloud/aiplatform_v1/types/vertex_rag_data_service.py @@ -41,9 +41,12 @@ "ListRagFilesResponse", "DeleteRagFileRequest", "CreateRagCorpusOperationMetadata", + "GetRagEngineConfigRequest", "UpdateRagCorpusRequest", "UpdateRagCorpusOperationMetadata", "ImportRagFilesOperationMetadata", + "UpdateRagEngineConfigRequest", + "UpdateRagEngineConfigOperationMetadata", }, ) @@ -441,6 +444,22 @@ class CreateRagCorpusOperationMetadata(proto.Message): ) +class GetRagEngineConfigRequest(proto.Message): + r"""Request message for + [VertexRagDataService.GetRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.GetRagEngineConfig] + + Attributes: + name (str): + Required. The name of the RagEngineConfig resource. Format: + ``projects/{project}/locations/{location}/ragEngineConfig`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + class UpdateRagCorpusRequest(proto.Message): r"""Request message for [VertexRagDataService.UpdateRagCorpus][google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagCorpus]. @@ -514,4 +533,40 @@ class ImportRagFilesOperationMetadata(proto.Message): ) +class UpdateRagEngineConfigRequest(proto.Message): + r"""Request message for + [VertexRagDataService.UpdateRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig]. + + Attributes: + rag_engine_config (google.cloud.aiplatform_v1.types.RagEngineConfig): + Required. The updated RagEngineConfig. + + NOTE: Downgrading your RagManagedDb's + ComputeTier could temporarily increase request + latencies until the operation is fully complete. + """ + + rag_engine_config: vertex_rag_data.RagEngineConfig = proto.Field( + proto.MESSAGE, + number=1, + message=vertex_rag_data.RagEngineConfig, + ) + + +class UpdateRagEngineConfigOperationMetadata(proto.Message): + r"""Runtime operation information for + [VertexRagDataService.UpdateRagEngineConfig][google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig]. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The operation generic information. 
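UpdateRagEngineConfig returns a long-running operation. A plausible calling pattern with the sync client, assuming the flattened rag_engine_config argument exercised by the unit tests further below (the resource name is a placeholder and application default credentials are assumed):

    from google.cloud import aiplatform_v1
    from google.cloud.aiplatform_v1 import types as aiplatform_types

    client = aiplatform_v1.VertexRagDataServiceClient()

    operation = client.update_rag_engine_config(
        rag_engine_config=aiplatform_types.RagEngineConfig(
            name="projects/my-project/locations/us-central1/ragEngineConfig",
            rag_managed_db_config=aiplatform_types.RagManagedDbConfig(
                scaled=aiplatform_types.RagManagedDbConfig.Scaled(),
            ),
        ),
    )

    # operation.metadata carries the generic operation information described here.
    print(operation.metadata)
    response = operation.result()  # block until the tier change completes
    print(response)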
+ """ + + generic_metadata: operation.GenericOperationMetadata = proto.Field( + proto.MESSAGE, + number=1, + message=operation.GenericOperationMetadata, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 2782cd4534..abb871b44e 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -1053,6 +1053,7 @@ from .types.schedule_service import PauseScheduleRequest from .types.schedule_service import ResumeScheduleRequest from .types.schedule_service import UpdateScheduleRequest +from .types.service_networking import DnsPeeringConfig from .types.service_networking import PrivateServiceConnectConfig from .types.service_networking import PscAutomatedEndpoints from .types.service_networking import PSCAutomationConfig @@ -1587,6 +1588,7 @@ "DistillationDataStats", "DistillationHyperParameters", "DistillationSpec", + "DnsPeeringConfig", "DoubleArray", "DynamicRetrievalConfig", "EncryptionSpec", diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 664e401502..1a7670e5f5 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -1166,6 +1166,7 @@ UpdateScheduleRequest, ) from .service_networking import ( + DnsPeeringConfig, PrivateServiceConnectConfig, PscAutomatedEndpoints, PSCAutomationConfig, @@ -2351,6 +2352,7 @@ "PauseScheduleRequest", "ResumeScheduleRequest", "UpdateScheduleRequest", + "DnsPeeringConfig", "PrivateServiceConnectConfig", "PscAutomatedEndpoints", "PSCAutomationConfig", diff --git a/google/cloud/aiplatform_v1beta1/types/service_networking.py b/google/cloud/aiplatform_v1beta1/types/service_networking.py index f9fa1dfe29..8e89496812 100644 --- a/google/cloud/aiplatform_v1beta1/types/service_networking.py +++ b/google/cloud/aiplatform_v1beta1/types/service_networking.py @@ -27,6 +27,7 @@ "PrivateServiceConnectConfig", "PscAutomatedEndpoints", "PscInterfaceConfig", + "DnsPeeringConfig", }, ) @@ -145,12 +146,60 @@ class PscInterfaceConfig(proto.Message): attachment] (https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments). This field is only used for resources using PSC-I. + dns_peering_configs (MutableSequence[google.cloud.aiplatform_v1beta1.types.DnsPeeringConfig]): + Optional. DNS peering configurations. When + specified, Vertex AI will attempt to configure + DNS peering zones in the tenant project VPC to + resolve the specified domains using the target + network's Cloud DNS. The user must grant the + dns.peer role to the Vertex AI Service Agent on + the target project. """ network_attachment: str = proto.Field( proto.STRING, number=1, ) + dns_peering_configs: MutableSequence["DnsPeeringConfig"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="DnsPeeringConfig", + ) + + +class DnsPeeringConfig(proto.Message): + r"""DNS peering configuration. These configurations are used to + create DNS peering zones in the Vertex tenant project VPC, + enabling resolution of records within the specified domain + hosted in the target network's Cloud DNS. + + Attributes: + domain (str): + Required. The DNS name suffix of the zone + being peered to, e.g., + "my-internal-domain.corp.". Must end with a dot. + target_project (str): + Required. The project ID hosting the Cloud + DNS managed zone that contains the 'domain'. 
The + Vertex AI Service Agent requires the dns.peer + role on this project. + target_network (str): + Required. The VPC network name in the target_project where + the DNS zone specified by 'domain' is visible. + """ + + domain: str = proto.Field( + proto.STRING, + number=1, + ) + target_project: str = proto.Field( + proto.STRING, + number=2, + ) + target_network: str = proto.Field( + proto.STRING, + number=3, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py index 557ce4d6bd..3906996e5b 100644 --- a/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py +++ b/google/cloud/aiplatform_v1beta1/types/vertex_rag_data.py @@ -1497,23 +1497,38 @@ class RagManagedDbConfig(proto.Message): Attributes: enterprise (google.cloud.aiplatform_v1beta1.types.RagManagedDbConfig.Enterprise): - Sets the RagManagedDb to the Enterprise tier. - This is the default tier if not explicitly - chosen. + Deprecated: Please use ``Scaled`` tier instead. Sets the + RagManagedDb to the Enterprise tier. This is the default + tier if not explicitly chosen. + + This field is a member of `oneof`_ ``tier``. + scaled (google.cloud.aiplatform_v1beta1.types.RagManagedDbConfig.Scaled): + Sets the RagManagedDb to the Scaled tier. This field is a member of `oneof`_ ``tier``. basic (google.cloud.aiplatform_v1beta1.types.RagManagedDbConfig.Basic): Sets the RagManagedDb to the Basic tier. + This field is a member of `oneof`_ ``tier``. + unprovisioned (google.cloud.aiplatform_v1beta1.types.RagManagedDbConfig.Unprovisioned): + Sets the RagManagedDb to the Unprovisioned + tier. + This field is a member of `oneof`_ ``tier``. """ class Enterprise(proto.Message): - r"""Enterprise tier offers production grade performance along - with autoscaling functionality. It is suitable for customers - with large amounts of data or performance sensitive workloads. + r"""Deprecated: Please use ``Scaled`` tier instead. Enterprise tier + offers production grade performance along with autoscaling + functionality. It is suitable for customers with large amounts of + data or performance sensitive workloads. - NOTE: This is the default tier if not explicitly chosen. + """ + + class Scaled(proto.Message): + r"""Scaled tier offers production grade performance along with + autoscaling functionality. It is suitable for customers with + large amounts of data or performance sensitive workloads. """ @@ -1526,6 +1541,19 @@ class Basic(proto.Message): - Latency insensitive workload. - Only using RAG Engine with external vector DBs. + NOTE: This is the default tier if not explicitly chosen. + + """ + + class Unprovisioned(proto.Message): + r"""Disables the RAG Engine service and deletes all your data + held within this service. This will halt the billing of the + service. + + NOTE: Once deleted the data cannot be recovered. To start using + RAG Engine again, you will need to update the tier by calling + the UpdateRagEngineConfig API. 
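In the v1beta1 surface, a caller moving off the deprecated Enterprise tier, or re-provisioning after Unprovisioned, would presumably set the scaled member of the oneof instead; a small sketch with placeholder names:

    from google.cloud.aiplatform_v1beta1 import types as beta_types

    # Scaled is the production-grade replacement for the deprecated Enterprise tier.
    db_config = beta_types.RagManagedDbConfig(
        scaled=beta_types.RagManagedDbConfig.Scaled(),
    )

    # Per the note above, sending this config through
    # VertexRagDataServiceClient.update_rag_engine_config and waiting on the
    # returned operation is how a project leaves the Unprovisioned state.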
+ """ enterprise: Enterprise = proto.Field( @@ -1534,12 +1562,24 @@ class Basic(proto.Message): oneof="tier", message=Enterprise, ) + scaled: Scaled = proto.Field( + proto.MESSAGE, + number=4, + oneof="tier", + message=Scaled, + ) basic: Basic = proto.Field( proto.MESSAGE, number=2, oneof="tier", message=Basic, ) + unprovisioned: Unprovisioned = proto.Field( + proto.MESSAGE, + number=3, + oneof="tier", + message=Unprovisioned, + ) class RagEngineConfig(proto.Message): diff --git a/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_async.py b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_async.py new file mode 100644 index 0000000000..74ec21c802 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetRagEngineConfig +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_VertexRagDataService_GetRagEngineConfig_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_get_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetRagEngineConfigRequest( + name="name_value", + ) + + # Make the request + response = await client.get_rag_engine_config(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_VertexRagDataService_GetRagEngineConfig_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_sync.py b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_sync.py new file mode 100644 index 0000000000..3560db6861 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetRagEngineConfig +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_VertexRagDataService_GetRagEngineConfig_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_get_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.GetRagEngineConfigRequest( + name="name_value", + ) + + # Make the request + response = client.get_rag_engine_config(request=request) + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_VertexRagDataService_GetRagEngineConfig_sync] diff --git a/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_async.py b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_async.py new file mode 100644 index 0000000000..bdecd52509 --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_async.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateRagEngineConfig +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_VertexRagDataService_UpdateRagEngineConfig_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +async def sample_update_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceAsyncClient() + + # Initialize request argument(s) + request = aiplatform_v1.UpdateRagEngineConfigRequest( + ) + + # Make the request + operation = client.update_rag_engine_config(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_VertexRagDataService_UpdateRagEngineConfig_async] diff --git a/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_sync.py b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_sync.py new file mode 100644 index 0000000000..3a1d8671dc --- /dev/null +++ b/samples/generated_samples/aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_sync.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateRagEngineConfig +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-cloud-aiplatform + + +# [START aiplatform_v1_generated_VertexRagDataService_UpdateRagEngineConfig_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.cloud import aiplatform_v1 + + +def sample_update_rag_engine_config(): + # Create a client + client = aiplatform_v1.VertexRagDataServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1.UpdateRagEngineConfigRequest( + ) + + # Make the request + operation = client.update_rag_engine_config(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + +# [END aiplatform_v1_generated_VertexRagDataService_UpdateRagEngineConfig_sync] diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index cbfa30ab93..f8471f4aef 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -49732,6 +49732,167 @@ ], "title": "aiplatform_v1_generated_vertex_rag_data_service_get_rag_corpus_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.VertexRagDataServiceAsyncClient", + "shortName": "VertexRagDataServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.VertexRagDataServiceAsyncClient.get_rag_engine_config", + "method": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService.GetRagEngineConfig", + "service": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService", + "shortName": "VertexRagDataService" + }, + "shortName": "GetRagEngineConfig" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.GetRagEngineConfigRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.RagEngineConfig", + "shortName": "get_rag_engine_config" + }, + "description": "Sample for GetRagEngineConfig", + "file": "aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_VertexRagDataService_GetRagEngineConfig_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.VertexRagDataServiceClient", + "shortName": "VertexRagDataServiceClient" + }, + "fullName": "google.cloud.aiplatform_v1.VertexRagDataServiceClient.get_rag_engine_config", + "method": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService.GetRagEngineConfig", + "service": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService", + "shortName": "VertexRagDataService" + }, + "shortName": 
"GetRagEngineConfig" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.GetRagEngineConfigRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.cloud.aiplatform_v1.types.RagEngineConfig", + "shortName": "get_rag_engine_config" + }, + "description": "Sample for GetRagEngineConfig", + "file": "aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_VertexRagDataService_GetRagEngineConfig_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_vertex_rag_data_service_get_rag_engine_config_sync.py" + }, { "canonical": true, "clientMethod": { @@ -50545,6 +50706,167 @@ ], "title": "aiplatform_v1_generated_vertex_rag_data_service_update_rag_corpus_sync.py" }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.cloud.aiplatform_v1.VertexRagDataServiceAsyncClient", + "shortName": "VertexRagDataServiceAsyncClient" + }, + "fullName": "google.cloud.aiplatform_v1.VertexRagDataServiceAsyncClient.update_rag_engine_config", + "method": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig", + "service": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService", + "shortName": "VertexRagDataService" + }, + "shortName": "UpdateRagEngineConfig" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.UpdateRagEngineConfigRequest" + }, + { + "name": "rag_engine_config", + "type": "google.cloud.aiplatform_v1.types.RagEngineConfig" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.api_core.operation_async.AsyncOperation", + "shortName": "update_rag_engine_config" + }, + "description": "Sample for UpdateRagEngineConfig", + "file": "aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_VertexRagDataService_UpdateRagEngineConfig_async", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.cloud.aiplatform_v1.VertexRagDataServiceClient", + "shortName": "VertexRagDataServiceClient" + }, + "fullName": 
"google.cloud.aiplatform_v1.VertexRagDataServiceClient.update_rag_engine_config", + "method": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService.UpdateRagEngineConfig", + "service": { + "fullName": "google.cloud.aiplatform.v1.VertexRagDataService", + "shortName": "VertexRagDataService" + }, + "shortName": "UpdateRagEngineConfig" + }, + "parameters": [ + { + "name": "request", + "type": "google.cloud.aiplatform_v1.types.UpdateRagEngineConfigRequest" + }, + { + "name": "rag_engine_config", + "type": "google.cloud.aiplatform_v1.types.RagEngineConfig" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "update_rag_engine_config" + }, + "description": "Sample for UpdateRagEngineConfig", + "file": "aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "aiplatform_v1_generated_VertexRagDataService_UpdateRagEngineConfig_sync", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "aiplatform_v1_generated_vertex_rag_data_service_update_rag_engine_config_sync.py" + }, { "canonical": true, "clientMethod": { diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index 7d8ddd773c..2895f8f9cd 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -24434,7 +24434,16 @@ def test_create_custom_job_rest_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, "protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -26083,7 +26092,16 @@ def test_create_hyperparameter_tuning_job_rest_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, "protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -26961,7 +26979,14 @@ def test_create_nas_job_rest_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + 
"dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "base_output_directory": { "output_uri_prefix": "output_uri_prefix_value" @@ -31842,7 +31867,16 @@ async def test_create_custom_job_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, "protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -33656,7 +33690,16 @@ async def test_create_hyperparameter_tuning_job_rest_asyncio_call_success(reques "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, "protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -34632,7 +34675,14 @@ async def test_create_nas_job_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "base_output_directory": { "output_uri_prefix": "output_uri_prefix_value" diff --git a/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py b/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py index e723a6ff25..3cab74c261 100644 --- a/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_persistent_resource_service.py @@ -5336,7 +5336,16 @@ def test_create_persistent_resource_rest_call_success(request_type): "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { @@ -6025,7 +6034,16 @@ def test_update_persistent_resource_rest_call_success(request_type): "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { 
@@ -7219,7 +7237,16 @@ async def test_create_persistent_resource_rest_asyncio_call_success(request_type "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { @@ -7972,7 +7999,16 @@ async def test_update_persistent_resource_rest_asyncio_call_success(request_type "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index 0a6aba329d..fc5059da0c 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -10079,7 +10079,16 @@ def test_create_pipeline_job_rest_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, "schedule_name": "schedule_name_value", @@ -13098,7 +13107,16 @@ async def test_create_pipeline_job_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, "schedule_name": "schedule_name_value", diff --git a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py index 9ccef983c3..b53d344d73 100644 --- a/tests/unit/gapic/aiplatform_v1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_schedule_service.py @@ -5701,7 +5701,14 @@ def test_create_schedule_rest_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, @@ -6720,7 +6727,14 @@ def test_update_schedule_rest_call_success(request_type): 
"reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, @@ -7929,7 +7943,14 @@ async def test_create_schedule_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, @@ -9048,7 +9069,14 @@ async def test_update_schedule_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, diff --git a/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py index f9e920487e..044c4dc85a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_vertex_rag_data_service.py @@ -5084,6 +5084,692 @@ async def test_delete_rag_file_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.UpdateRagEngineConfigRequest, + dict, + ], +) +def test_update_rag_engine_config(request_type, transport: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.update_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.UpdateRagEngineConfigRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_rag_engine_config_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. 
+ request = vertex_rag_data_service.UpdateRagEngineConfigRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_rag_engine_config(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.UpdateRagEngineConfigRequest() + + +def test_update_rag_engine_config_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_rag_engine_config + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_rag_engine_config + ] = mock_rpc + request = {} + client.update_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_rag_engine_config(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_rag_engine_config_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_rag_engine_config + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.update_rag_engine_config + ] = mock_rpc + + request = {} + await client.update_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. 
+ # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.update_rag_engine_config(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_rag_engine_config_async( + transport: str = "grpc_asyncio", + request_type=vertex_rag_data_service.UpdateRagEngineConfigRequest, +): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.update_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.UpdateRagEngineConfigRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_rag_engine_config_async_from_dict(): + await test_update_rag_engine_config_async(request_type=dict) + + +def test_update_rag_engine_config_field_headers(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vertex_rag_data_service.UpdateRagEngineConfigRequest() + + request.rag_engine_config.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.update_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "rag_engine_config.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_rag_engine_config_field_headers_async(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vertex_rag_data_service.UpdateRagEngineConfigRequest() + + request.rag_engine_config.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.update_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "rag_engine_config.name=name_value", + ) in kw["metadata"] + + +def test_update_rag_engine_config_flattened(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_rag_engine_config( + rag_engine_config=vertex_rag_data.RagEngineConfig(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].rag_engine_config + mock_val = vertex_rag_data.RagEngineConfig(name="name_value") + assert arg == mock_val + + +def test_update_rag_engine_config_flattened_error(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_rag_engine_config( + vertex_rag_data_service.UpdateRagEngineConfigRequest(), + rag_engine_config=vertex_rag_data.RagEngineConfig(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_update_rag_engine_config_flattened_async(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_rag_engine_config( + rag_engine_config=vertex_rag_data.RagEngineConfig(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].rag_engine_config + mock_val = vertex_rag_data.RagEngineConfig(name="name_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_rag_engine_config_flattened_error_async(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.update_rag_engine_config( + vertex_rag_data_service.UpdateRagEngineConfigRequest(), + rag_engine_config=vertex_rag_data.RagEngineConfig(name="name_value"), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.GetRagEngineConfigRequest, + dict, + ], +) +def test_get_rag_engine_config(request_type, transport: str = "grpc"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data.RagEngineConfig( + name="name_value", + ) + response = client.get_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.GetRagEngineConfigRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, vertex_rag_data.RagEngineConfig) + assert response.name == "name_value" + + +def test_get_rag_engine_config_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = vertex_rag_data_service.GetRagEngineConfigRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_rag_engine_config(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == vertex_rag_data_service.GetRagEngineConfigRequest( + name="name_value", + ) + + +def test_get_rag_engine_config_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_rag_engine_config + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. 
+ ) + client._transport._wrapped_methods[ + client._transport.get_rag_engine_config + ] = mock_rpc + request = {} + client.get_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_rag_engine_config(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_rag_engine_config_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_rag_engine_config + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_rag_engine_config + ] = mock_rpc + + request = {} + await client.get_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_rag_engine_config(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_rag_engine_config_async( + transport: str = "grpc_asyncio", + request_type=vertex_rag_data_service.GetRagEngineConfigRequest, +): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vertex_rag_data.RagEngineConfig( + name="name_value", + ) + ) + response = await client.get_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = vertex_rag_data_service.GetRagEngineConfigRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, vertex_rag_data.RagEngineConfig) + assert response.name == "name_value" + + +@pytest.mark.asyncio +async def test_get_rag_engine_config_async_from_dict(): + await test_get_rag_engine_config_async(request_type=dict) + + +def test_get_rag_engine_config_field_headers(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vertex_rag_data_service.GetRagEngineConfigRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + call.return_value = vertex_rag_data.RagEngineConfig() + client.get_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_rag_engine_config_field_headers_async(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vertex_rag_data_service.GetRagEngineConfigRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vertex_rag_data.RagEngineConfig() + ) + await client.get_rag_engine_config(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_rag_engine_config_flattened(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data.RagEngineConfig() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_rag_engine_config( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_rag_engine_config_flattened_error(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_rag_engine_config( + vertex_rag_data_service.GetRagEngineConfigRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_rag_engine_config_flattened_async(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = vertex_rag_data.RagEngineConfig() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vertex_rag_data.RagEngineConfig() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = await client.get_rag_engine_config( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_rag_engine_config_flattened_error_async(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_rag_engine_config( + vertex_rag_data_service.GetRagEngineConfigRequest(), + name="name_value", + ) + + def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call @@ -5098,41 +5784,427 @@ def test_create_rag_corpus_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.create_rag_corpus in client._transport._wrapped_methods + assert client._transport.create_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_rag_corpus + ] = mock_rpc + + request = {} + client.create_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_rag_corpus_rest_required_fields( + request_type=vertex_rag_data_service.CreateRagCorpusRequest, +): + transport_class = transports.VertexRagDataServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_rag_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_rag_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.create_rag_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_rag_corpus_rest_unset_required_fields(): + transport = transports.VertexRagDataServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_rag_corpus._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "ragCorpus", + ) + ) + ) + + +def test_create_rag_corpus_rest_flattened(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/locations/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + rag_corpus=vertex_rag_data.RagCorpus( + vector_db_config=vertex_rag_data.RagVectorDbConfig( + rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( + knn=None + ) + ) + ), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.create_rag_corpus(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/locations/*}/ragCorpora" % client.transport._host, + args[1], + ) + + +def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_rag_corpus( + vertex_rag_data_service.CreateRagCorpusRequest(), + parent="parent_value", + rag_corpus=vertex_rag_data.RagCorpus( + vector_db_config=vertex_rag_data.RagVectorDbConfig( + rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( + knn=None + ) + ) + ), + ) + + +def test_update_rag_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_rag_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_rag_corpus + ] = mock_rpc + + request = {} + client.update_rag_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_rag_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_rag_corpus_rest_required_fields( + request_type=vertex_rag_data_service.UpdateRagCorpusRequest, +): + transport_class = transports.VertexRagDataServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_rag_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_rag_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.update_rag_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_rag_corpus_rest_unset_required_fields(): + transport = transports.VertexRagDataServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_rag_corpus._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("ragCorpus",))) + + +def test_update_rag_corpus_rest_flattened(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "rag_corpus": { + "name": "projects/sample1/locations/sample2/ragCorpora/sample3" + } + } + + # get truthy value for each flattened field + mock_args = dict( + rag_corpus=vertex_rag_data.RagCorpus( + vector_db_config=vertex_rag_data.RagVectorDbConfig( + rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( + knn=None + ) + ) + ), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.update_rag_corpus(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{rag_corpus.name=projects/*/locations/*/ragCorpora/*}" + % client.transport._host, + args[1], + ) + + +def test_update_rag_corpus_rest_flattened_error(transport: str = "rest"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.update_rag_corpus( + vertex_rag_data_service.UpdateRagCorpusRequest(), + rag_corpus=vertex_rag_data.RagCorpus( + vector_db_config=vertex_rag_data.RagVectorDbConfig( + rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( + knn=None + ) + ) + ), + ) + + +def test_get_rag_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_rag_corpus in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[ - client._transport.create_rag_corpus - ] = mock_rpc + client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc request = {} - client.create_rag_corpus(request) + client.get_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods build a cached wrapper on first rpc call - # subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - client.create_rag_corpus(request) + client.get_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_create_rag_corpus_rest_required_fields( - request_type=vertex_rag_data_service.CreateRagCorpusRequest, +def test_get_rag_corpus_rest_required_fields( + request_type=vertex_rag_data_service.GetRagCorpusRequest, ): transport_class = transports.VertexRagDataServiceRestTransport request_init = {} - request_init["parent"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5143,21 +6215,21 @@ def test_create_rag_corpus_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_rag_corpus._get_unset_required_fields(jsonified_request) + ).get_rag_corpus._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).create_rag_corpus._get_unset_required_fields(jsonified_request) + ).get_rag_corpus._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5166,7 +6238,7 @@ def test_create_rag_corpus_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data.RagCorpus() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5178,45 +6250,39 @@ def test_create_rag_corpus_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "get", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data.RagCorpus.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.create_rag_corpus(request) + response = client.get_rag_corpus(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_create_rag_corpus_rest_unset_required_fields(): +def test_get_rag_corpus_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.create_rag_corpus._get_unset_required_fields({}) - assert set(unset_fields) == ( - set(()) - & set( - ( - "parent", - "ragCorpus", - ) - ) - ) + unset_fields = transport.get_rag_corpus._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) -def test_create_rag_corpus_rest_flattened(): +def test_get_rag_corpus_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5225,45 +6291,42 @@ def test_create_rag_corpus_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data.RagCorpus() # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "name": "projects/sample1/locations/sample2/ragCorpora/sample3" + } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", - rag_corpus=vertex_rag_data.RagCorpus( - vector_db_config=vertex_rag_data.RagVectorDbConfig( - rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( - knn=None - ) - ) - ), + name="name_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = vertex_rag_data.RagCorpus.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.create_rag_corpus(**mock_args) + client.get_rag_corpus(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{parent=projects/*/locations/*}/ragCorpora" % client.transport._host, + "%s/v1/{name=projects/*/locations/*/ragCorpora/*}" % client.transport._host, args[1], ) -def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"): +def test_get_rag_corpus_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5272,20 +6335,13 @@ def test_create_rag_corpus_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.create_rag_corpus( - vertex_rag_data_service.CreateRagCorpusRequest(), - parent="parent_value", - rag_corpus=vertex_rag_data.RagCorpus( - vector_db_config=vertex_rag_data.RagVectorDbConfig( - rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( - knn=None - ) - ) - ), + client.get_rag_corpus( + vertex_rag_data_service.GetRagCorpusRequest(), + name="name_value", ) -def test_update_rag_corpus_rest_use_cached_wrapped_rpc(): +def test_list_rag_corpora_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -5299,7 +6355,7 @@ def test_update_rag_corpus_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.update_rag_corpus in client._transport._wrapped_methods + assert client._transport.list_rag_corpora in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -5307,32 +6363,29 @@ def test_update_rag_corpus_rest_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.update_rag_corpus + client._transport.list_rag_corpora ] = mock_rpc request = {} - client.update_rag_corpus(request) + client.list_rag_corpora(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - # Operation methods build a cached wrapper on first rpc call - # subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - client.update_rag_corpus(request) + client.list_rag_corpora(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_update_rag_corpus_rest_required_fields( - request_type=vertex_rag_data_service.UpdateRagCorpusRequest, +def test_list_rag_corpora_rest_required_fields( + request_type=vertex_rag_data_service.ListRagCorporaRequest, ): transport_class = transports.VertexRagDataServiceRestTransport request_init = {} + request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5343,17 +6396,28 @@ def test_update_rag_corpus_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).update_rag_corpus._get_unset_required_fields(jsonified_request) + ).list_rag_corpora._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present + jsonified_request["parent"] = "parent_value" + unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).update_rag_corpus._get_unset_required_fields(jsonified_request) + ).list_rag_corpora._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -5362,7 +6426,7 @@ def test_update_rag_corpus_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagCorporaResponse() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5374,37 +6438,49 @@ def test_update_rag_corpus_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "patch", + "method": "get", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagCorporaResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.update_rag_corpus(request) + response = client.list_rag_corpora(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_update_rag_corpus_rest_unset_required_fields(): +def test_list_rag_corpora_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.update_rag_corpus._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("ragCorpus",))) + unset_fields = transport.list_rag_corpora._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) -def test_update_rag_corpus_rest_flattened(): +def test_list_rag_corpora_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5413,49 +6489,40 @@ def test_update_rag_corpus_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagCorporaResponse() # get arguments that satisfy an http rule for this method - sample_request = { - "rag_corpus": { - "name": "projects/sample1/locations/sample2/ragCorpora/sample3" - } - } + sample_request = {"parent": "projects/sample1/locations/sample2"} # get truthy value for each flattened field mock_args = dict( - rag_corpus=vertex_rag_data.RagCorpus( - vector_db_config=vertex_rag_data.RagVectorDbConfig( - rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( - knn=None - ) - ) - ), + parent="parent_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagCorporaResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.update_rag_corpus(**mock_args) + client.list_rag_corpora(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
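The assertions that follow confirm the flattened `parent` argument is transcoded into the REST path for ragCorpora rather than the query string, and the pager test after them checks that successive ListRagCorporaResponse pages are stitched together via next_page_token. From the caller's side this surfaces as a lazily paging iterable; a rough sketch, assuming a configured client and an illustrative parent value:

    # Sketch only; the parent resource name is a placeholder.
    pager = client.list_rag_corpora(parent="projects/my-project/locations/us-central1")

    # Iterating the pager transparently fetches page after page of RagCorpus results.
    for corpus in pager:
        print(corpus.name)

    # Page-level access mirrors the next_page_token checks in the pager test below.
    for page in client.list_rag_corpora(parent="projects/my-project/locations/us-central1").pages:
        print(page.raw_page.next_page_token)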
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{rag_corpus.name=projects/*/locations/*/ragCorpora/*}" - % client.transport._host, + "%s/v1/{parent=projects/*/locations/*}/ragCorpora" % client.transport._host, args[1], ) -def test_update_rag_corpus_rest_flattened_error(transport: str = "rest"): +def test_list_rag_corpora_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5464,19 +6531,76 @@ def test_update_rag_corpus_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.update_rag_corpus( - vertex_rag_data_service.UpdateRagCorpusRequest(), - rag_corpus=vertex_rag_data.RagCorpus( - vector_db_config=vertex_rag_data.RagVectorDbConfig( - rag_managed_db=vertex_rag_data.RagVectorDbConfig.RagManagedDb( - knn=None - ) - ) + client.list_rag_corpora( + vertex_rag_data_service.ListRagCorporaRequest(), + parent="parent_value", + ) + + +def test_list_rag_corpora_rest_pager(transport: str = "rest"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagCorporaResponse( + rag_corpora=[ + vertex_rag_data.RagCorpus(), + vertex_rag_data.RagCorpus(), + ], ), ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + vertex_rag_data_service.ListRagCorporaResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/locations/sample2"} + + pager = client.list_rag_corpora(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in results) + + pages = list(client.list_rag_corpora(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token -def test_get_rag_corpus_rest_use_cached_wrapped_rpc(): +def test_delete_rag_corpus_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -5490,30 +6614,36 @@ def test_get_rag_corpus_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert 
client._transport.get_rag_corpus in client._transport._wrapped_methods + assert client._transport.delete_rag_corpus in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.get_rag_corpus] = mock_rpc + client._transport._wrapped_methods[ + client._transport.delete_rag_corpus + ] = mock_rpc request = {} - client.get_rag_corpus(request) + client.delete_rag_corpus(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.get_rag_corpus(request) + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_corpus(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_get_rag_corpus_rest_required_fields( - request_type=vertex_rag_data_service.GetRagCorpusRequest, +def test_delete_rag_corpus_rest_required_fields( + request_type=vertex_rag_data_service.DeleteRagCorpusRequest, ): transport_class = transports.VertexRagDataServiceRestTransport @@ -5529,7 +6659,7 @@ def test_get_rag_corpus_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_rag_corpus._get_unset_required_fields(jsonified_request) + ).delete_rag_corpus._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -5538,7 +6668,9 @@ def test_get_rag_corpus_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_rag_corpus._get_unset_required_fields(jsonified_request) + ).delete_rag_corpus._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("force",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -5552,7 +6684,7 @@ def test_get_rag_corpus_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagCorpus() + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5564,39 +6696,36 @@ def test_get_rag_corpus_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "delete", "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data.RagCorpus.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.get_rag_corpus(request) + response = client.delete_rag_corpus(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_get_rag_corpus_rest_unset_required_fields(): +def test_delete_rag_corpus_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.get_rag_corpus._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name",))) + unset_fields = transport.delete_rag_corpus._get_unset_required_fields({}) + assert set(unset_fields) == (set(("force",)) & set(("name",))) -def test_get_rag_corpus_rest_flattened(): +def test_delete_rag_corpus_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5605,7 +6734,7 @@ def test_get_rag_corpus_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagCorpus() + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method sample_request = { @@ -5621,14 +6750,12 @@ def test_get_rag_corpus_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = vertex_rag_data.RagCorpus.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.get_rag_corpus(**mock_args) + client.delete_rag_corpus(**mock_args) # Establish that the underlying call was made with the expected # request object values. @@ -5640,7 +6767,7 @@ def test_get_rag_corpus_rest_flattened(): ) -def test_get_rag_corpus_rest_flattened_error(transport: str = "rest"): +def test_delete_rag_corpus_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5649,13 +6776,13 @@ def test_get_rag_corpus_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.get_rag_corpus( - vertex_rag_data_service.GetRagCorpusRequest(), + client.delete_rag_corpus( + vertex_rag_data_service.DeleteRagCorpusRequest(), name="name_value", ) -def test_list_rag_corpora_rest_use_cached_wrapped_rpc(): +def test_upload_rag_file_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -5669,32 +6796,30 @@ def test_list_rag_corpora_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.list_rag_corpora in client._transport._wrapped_methods + assert client._transport.upload_rag_file in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[ - client._transport.list_rag_corpora - ] = mock_rpc + client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc request = {} - client.list_rag_corpora(request) + client.upload_rag_file(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.list_rag_corpora(request) + client.upload_rag_file(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_list_rag_corpora_rest_required_fields( - request_type=vertex_rag_data_service.ListRagCorporaRequest, +def test_upload_rag_file_rest_required_fields( + request_type=vertex_rag_data_service.UploadRagFileRequest, ): transport_class = transports.VertexRagDataServiceRestTransport @@ -5710,7 +6835,7 @@ def test_list_rag_corpora_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_rag_corpora._get_unset_required_fields(jsonified_request) + ).upload_rag_file._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -5719,14 +6844,7 @@ def test_list_rag_corpora_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_rag_corpora._get_unset_required_fields(jsonified_request) - # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set( - ( - "page_size", - "page_token", - ) - ) + ).upload_rag_file._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -5740,7 +6858,7 @@ def test_list_rag_corpora_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.ListRagCorporaResponse() + return_value = vertex_rag_data_service.UploadRagFileResponse() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -5752,16 +6870,17 @@ def test_list_rag_corpora_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "post", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagCorporaResponse.pb( + return_value = vertex_rag_data_service.UploadRagFileResponse.pb( return_value ) json_return_value = json_format.MessageToJson(return_value) @@ -5770,31 +6889,32 @@ def test_list_rag_corpora_rest_required_fields( req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.list_rag_corpora(request) + response = client.upload_rag_file(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_list_rag_corpora_rest_unset_required_fields(): +def test_upload_rag_file_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.list_rag_corpora._get_unset_required_fields({}) + unset_fields = transport.upload_rag_file._get_unset_required_fields({}) assert set(unset_fields) == ( - set( + set(()) + & set( ( - "pageSize", - "pageToken", + "parent", + "ragFile", + "uploadRagFileConfig", ) ) - & set(("parent",)) ) -def test_list_rag_corpora_rest_flattened(): +def test_upload_rag_file_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -5803,14 +6923,28 @@ def test_list_rag_corpora_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = vertex_rag_data_service.ListRagCorporaResponse() + return_value = vertex_rag_data_service.UploadRagFileResponse() # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "parent": "projects/sample1/locations/sample2/ragCorpora/sample3" + } # get truthy value for each flattened field mock_args = dict( parent="parent_value", + rag_file=vertex_rag_data.RagFile( + gcs_source=io.GcsSource(uris=["uris_value"]) + ), + upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( + rag_file_transformation_config=vertex_rag_data.RagFileTransformationConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + fixed_length_chunking=vertex_rag_data.RagFileChunkingConfig.FixedLengthChunking( + chunk_size=1075 + ) + ) + ) + ), ) mock_args.update(sample_request) @@ -5818,25 +6952,26 @@ def test_list_rag_corpora_rest_flattened(): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagCorporaResponse.pb(return_value) + return_value = vertex_rag_data_service.UploadRagFileResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.list_rag_corpora(**mock_args) + client.upload_rag_file(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{parent=projects/*/locations/*}/ragCorpora" % client.transport._host, + "%s/v1/{parent=projects/*/locations/*/ragCorpora/*}/ragFiles:upload" + % client.transport._host, args[1], ) -def test_list_rag_corpora_rest_flattened_error(transport: str = "rest"): +def test_upload_rag_file_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -5845,76 +6980,25 @@ def test_list_rag_corpora_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_rag_corpora( - vertex_rag_data_service.ListRagCorporaRequest(), + client.upload_rag_file( + vertex_rag_data_service.UploadRagFileRequest(), parent="parent_value", - ) - - -def test_list_rag_corpora_rest_pager(transport: str = "rest"): - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. 
- # with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - ], - next_page_token="ghi", + rag_file=vertex_rag_data.RagFile( + gcs_source=io.GcsSource(uris=["uris_value"]) ), - vertex_rag_data_service.ListRagCorporaResponse( - rag_corpora=[ - vertex_rag_data.RagCorpus(), - vertex_rag_data.RagCorpus(), - ], + upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( + rag_file_transformation_config=vertex_rag_data.RagFileTransformationConfig( + rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( + fixed_length_chunking=vertex_rag_data.RagFileChunkingConfig.FixedLengthChunking( + chunk_size=1075 + ) + ) + ) ), ) - # Two responses for two calls - response = response + response - - # Wrap the values into proper Response objs - response = tuple( - vertex_rag_data_service.ListRagCorporaResponse.to_json(x) for x in response - ) - return_values = tuple(Response() for i in response) - for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode("UTF-8") - return_val.status_code = 200 - req.side_effect = return_values - - sample_request = {"parent": "projects/sample1/locations/sample2"} - - pager = client.list_rag_corpora(request=sample_request) - - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, vertex_rag_data.RagCorpus) for i in results) - - pages = list(client.list_rag_corpora(request=sample_request).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token -def test_delete_rag_corpus_rest_use_cached_wrapped_rpc(): +def test_import_rag_files_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -5928,7 +7012,7 @@ def test_delete_rag_corpus_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.delete_rag_corpus in client._transport._wrapped_methods + assert client._transport.import_rag_files in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() @@ -5936,11 +7020,11 @@ def test_delete_rag_corpus_rest_use_cached_wrapped_rpc(): "foo" # operation_request.operation in compute client(s) expect a string. ) client._transport._wrapped_methods[ - client._transport.delete_rag_corpus + client._transport.import_rag_files ] = mock_rpc request = {} - client.delete_rag_corpus(request) + client.import_rag_files(request) # Establish that the underlying gRPC stub method was called. 
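These import tests fake an operations_pb2.Operation response because ImportRagFiles is a long-running operation; the flattened REST test further down builds an ImportRagFilesConfig with a GCS source and checks that it is posted to the ragFiles:import URL. A rough caller-side sketch, assuming a configured client and placeholder resource names; the call-count assertions just below only verify that the cached wrapper is reused:

    # Sketch only; names and URIs are placeholders.
    config = vertex_rag_data.ImportRagFilesConfig(
        gcs_source=io.GcsSource(uris=["gs://my-bucket/docs/"]),
    )
    operation = client.import_rag_files(
        parent="projects/my-project/locations/us-central1/ragCorpora/my-corpus",
        import_rag_files_config=config,
    )
    # Block until the import finishes and read the operation's result.
    response = operation.result()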
assert mock_rpc.call_count == 1 @@ -5949,20 +7033,20 @@ def test_delete_rag_corpus_rest_use_cached_wrapped_rpc(): # subsequent calls should use the cached wrapper wrapper_fn.reset_mock() - client.delete_rag_corpus(request) + client.import_rag_files(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_delete_rag_corpus_rest_required_fields( - request_type=vertex_rag_data_service.DeleteRagCorpusRequest, +def test_import_rag_files_rest_required_fields( + request_type=vertex_rag_data_service.ImportRagFilesRequest, ): transport_class = transports.VertexRagDataServiceRestTransport request_init = {} - request_init["name"] = "" + request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -5973,23 +7057,21 @@ def test_delete_rag_corpus_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_rag_corpus._get_unset_required_fields(jsonified_request) + ).import_rag_files._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = "name_value" + jsonified_request["parent"] = "parent_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_rag_corpus._get_unset_required_fields(jsonified_request) - # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set(("force",)) + ).import_rag_files._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "name" in jsonified_request - assert jsonified_request["name"] == "name_value" + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6010,9 +7092,10 @@ def test_delete_rag_corpus_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "delete", + "method": "post", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -6023,23 +7106,31 @@ def test_delete_rag_corpus_rest_required_fields( req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.delete_rag_corpus(request) + response = client.import_rag_files(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_delete_rag_corpus_rest_unset_required_fields(): +def test_import_rag_files_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.delete_rag_corpus._get_unset_required_fields({}) - assert set(unset_fields) == (set(("force",)) & set(("name",))) + unset_fields = transport.import_rag_files._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "importRagFilesConfig", + ) + ) + ) -def test_delete_rag_corpus_rest_flattened(): +def test_import_rag_files_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ 
-6052,12 +7143,15 @@ def test_delete_rag_corpus_rest_flattened(): # get arguments that satisfy an http rule for this method sample_request = { - "name": "projects/sample1/locations/sample2/ragCorpora/sample3" + "parent": "projects/sample1/locations/sample2/ragCorpora/sample3" } # get truthy value for each flattened field mock_args = dict( - name="name_value", + parent="parent_value", + import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ), ) mock_args.update(sample_request) @@ -6069,19 +7163,20 @@ def test_delete_rag_corpus_rest_flattened(): req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.delete_rag_corpus(**mock_args) + client.import_rag_files(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{name=projects/*/locations/*/ragCorpora/*}" % client.transport._host, + "%s/v1/{parent=projects/*/locations/*/ragCorpora/*}/ragFiles:import" + % client.transport._host, args[1], ) -def test_delete_rag_corpus_rest_flattened_error(transport: str = "rest"): +def test_import_rag_files_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6090,13 +7185,16 @@ def test_delete_rag_corpus_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.delete_rag_corpus( - vertex_rag_data_service.DeleteRagCorpusRequest(), - name="name_value", + client.import_rag_files( + vertex_rag_data_service.ImportRagFilesRequest(), + parent="parent_value", + import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( + gcs_source=io.GcsSource(uris=["uris_value"]) + ), ) -def test_upload_rag_file_rest_use_cached_wrapped_rpc(): +def test_get_rag_file_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -6110,35 +7208,35 @@ def test_upload_rag_file_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.upload_rag_file in client._transport._wrapped_methods + assert client._transport.get_rag_file in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.upload_rag_file] = mock_rpc + client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc request = {} - client.upload_rag_file(request) + client.get_rag_file(request) # Establish that the underlying gRPC stub method was called. 
assert mock_rpc.call_count == 1 - client.upload_rag_file(request) + client.get_rag_file(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_upload_rag_file_rest_required_fields( - request_type=vertex_rag_data_service.UploadRagFileRequest, +def test_get_rag_file_rest_required_fields( + request_type=vertex_rag_data_service.GetRagFileRequest, ): transport_class = transports.VertexRagDataServiceRestTransport request_init = {} - request_init["parent"] = "" + request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -6149,21 +7247,21 @@ def test_upload_rag_file_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).upload_rag_file._get_unset_required_fields(jsonified_request) + ).get_rag_file._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" + jsonified_request["name"] = "name_value" unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).upload_rag_file._get_unset_required_fields(jsonified_request) + ).get_rag_file._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6172,7 +7270,7 @@ def test_upload_rag_file_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.UploadRagFileResponse() + return_value = vertex_rag_data.RagFile() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -6184,51 +7282,39 @@ def test_upload_rag_file_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "get", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.UploadRagFileResponse.pb( - return_value - ) + return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.upload_rag_file(request) + response = client.get_rag_file(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_upload_rag_file_rest_unset_required_fields(): +def test_get_rag_file_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.upload_rag_file._get_unset_required_fields({}) - assert set(unset_fields) == ( - set(()) - & set( - ( - "parent", - "ragFile", - "uploadRagFileConfig", - ) - ) - ) + unset_fields = transport.get_rag_file._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) -def test_upload_rag_file_rest_flattened(): +def test_get_rag_file_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6237,28 +7323,16 @@ def test_upload_rag_file_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = vertex_rag_data_service.UploadRagFileResponse() + return_value = vertex_rag_data.RagFile() # get arguments that satisfy an http rule for this method sample_request = { - "parent": "projects/sample1/locations/sample2/ragCorpora/sample3" + "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", - rag_file=vertex_rag_data.RagFile( - gcs_source=io.GcsSource(uris=["uris_value"]) - ), - upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( - rag_file_transformation_config=vertex_rag_data.RagFileTransformationConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - fixed_length_chunking=vertex_rag_data.RagFileChunkingConfig.FixedLengthChunking( - chunk_size=1075 - ) - ) - ) - ), + name="name_value", ) mock_args.update(sample_request) @@ -6266,26 +7340,26 @@ def test_upload_rag_file_rest_flattened(): response_value = Response() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.UploadRagFileResponse.pb(return_value) + return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.upload_rag_file(**mock_args) + client.get_rag_file(**mock_args) # Establish that the underlying call was made with the expected # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{parent=projects/*/locations/*/ragCorpora/*}/ragFiles:upload" + "%s/v1/{name=projects/*/locations/*/ragCorpora/*/ragFiles/*}" % client.transport._host, args[1], ) -def test_upload_rag_file_rest_flattened_error(transport: str = "rest"): +def test_get_rag_file_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6294,25 +7368,13 @@ def test_upload_rag_file_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.upload_rag_file( - vertex_rag_data_service.UploadRagFileRequest(), - parent="parent_value", - rag_file=vertex_rag_data.RagFile( - gcs_source=io.GcsSource(uris=["uris_value"]) - ), - upload_rag_file_config=vertex_rag_data.UploadRagFileConfig( - rag_file_transformation_config=vertex_rag_data.RagFileTransformationConfig( - rag_file_chunking_config=vertex_rag_data.RagFileChunkingConfig( - fixed_length_chunking=vertex_rag_data.RagFileChunkingConfig.FixedLengthChunking( - chunk_size=1075 - ) - ) - ) - ), + client.get_rag_file( + vertex_rag_data_service.GetRagFileRequest(), + name="name_value", ) -def test_import_rag_files_rest_use_cached_wrapped_rpc(): +def test_list_rag_files_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -6326,36 +7388,30 @@ def test_import_rag_files_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.import_rag_files in client._transport._wrapped_methods + assert client._transport.list_rag_files in client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[ - client._transport.import_rag_files - ] = mock_rpc + client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc request = {} - client.import_rag_files(request) + client.list_rag_files(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods build a cached wrapper on first rpc call - # subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - client.import_rag_files(request) + client.list_rag_files(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_import_rag_files_rest_required_fields( - request_type=vertex_rag_data_service.ImportRagFilesRequest, +def test_list_rag_files_rest_required_fields( + request_type=vertex_rag_data_service.ListRagFilesRequest, ): transport_class = transports.VertexRagDataServiceRestTransport @@ -6371,7 +7427,7 @@ def test_import_rag_files_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).import_rag_files._get_unset_required_fields(jsonified_request) + ).list_rag_files._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -6380,7 +7436,14 @@ def test_import_rag_files_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).import_rag_files._get_unset_required_fields(jsonified_request) + ).list_rag_files._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -6394,7 +7457,7 @@ def test_import_rag_files_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagFilesResponse() # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -6406,45 +7469,47 @@ def test_import_rag_files_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "post", + "method": "get", "query_params": pb_request, } - transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.import_rag_files(request) + response = client.list_rag_files(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_import_rag_files_rest_unset_required_fields(): +def test_list_rag_files_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.import_rag_files._get_unset_required_fields({}) + unset_fields = transport.list_rag_files._get_unset_required_fields({}) assert set(unset_fields) == ( - set(()) - & set( + set( ( - "parent", - "importRagFilesConfig", + "pageSize", + "pageToken", ) ) + & set(("parent",)) ) -def test_import_rag_files_rest_flattened(): +def test_list_rag_files_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6453,7 +7518,7 @@ def test_import_rag_files_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagFilesResponse() # get arguments that satisfy an http rule for this method sample_request = { @@ -6463,34 +7528,33 @@ def test_import_rag_files_rest_flattened(): # get truthy value for each flattened field mock_args = dict( parent="parent_value", - import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( - gcs_source=io.GcsSource(uris=["uris_value"]) - ), ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.import_rag_files(**mock_args) + client.list_rag_files(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{parent=projects/*/locations/*/ragCorpora/*}/ragFiles:import" + "%s/v1/{parent=projects/*/locations/*/ragCorpora/*}/ragFiles" % client.transport._host, args[1], ) -def test_import_rag_files_rest_flattened_error(transport: str = "rest"): +def test_list_rag_files_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6499,16 +7563,78 @@ def test_import_rag_files_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.import_rag_files( - vertex_rag_data_service.ImportRagFilesRequest(), + client.list_rag_files( + vertex_rag_data_service.ListRagFilesRequest(), parent="parent_value", - import_rag_files_config=vertex_rag_data.ImportRagFilesConfig( - gcs_source=io.GcsSource(uris=["uris_value"]) + ) + + +def test_list_rag_files_rest_pager(transport: str = "rest"): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], + next_page_token="abc", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[], + next_page_token="def", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + ], + next_page_token="ghi", + ), + vertex_rag_data_service.ListRagFilesResponse( + rag_files=[ + vertex_rag_data.RagFile(), + vertex_rag_data.RagFile(), + ], ), ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + vertex_rag_data_service.ListRagFilesResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = { + "parent": "projects/sample1/locations/sample2/ragCorpora/sample3" + } + + pager = client.list_rag_files(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, vertex_rag_data.RagFile) for i in results) + + pages = list(client.list_rag_files(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token -def test_get_rag_file_rest_use_cached_wrapped_rpc(): +def test_delete_rag_file_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -6522,30 +7648,34 @@ def test_get_rag_file_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.get_rag_file in client._transport._wrapped_methods + assert client._transport.delete_rag_file in 
client._transport._wrapped_methods # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.get_rag_file] = mock_rpc + client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc request = {} - client.get_rag_file(request) + client.delete_rag_file(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.get_rag_file(request) + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.delete_rag_file(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_get_rag_file_rest_required_fields( - request_type=vertex_rag_data_service.GetRagFileRequest, +def test_delete_rag_file_rest_required_fields( + request_type=vertex_rag_data_service.DeleteRagFileRequest, ): transport_class = transports.VertexRagDataServiceRestTransport @@ -6561,7 +7691,7 @@ def test_get_rag_file_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_rag_file._get_unset_required_fields(jsonified_request) + ).delete_rag_file._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -6570,7 +7700,7 @@ def test_get_rag_file_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).get_rag_file._get_unset_required_fields(jsonified_request) + ).delete_rag_file._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -6584,7 +7714,7 @@ def test_get_rag_file_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagFile() + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -6596,39 +7726,36 @@ def test_get_rag_file_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "delete", "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.get_rag_file(request) + response = client.delete_rag_file(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_get_rag_file_rest_unset_required_fields(): +def test_delete_rag_file_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.get_rag_file._get_unset_required_fields({}) + unset_fields = transport.delete_rag_file._get_unset_required_fields({}) assert set(unset_fields) == (set(()) & set(("name",))) -def test_get_rag_file_rest_flattened(): +def test_delete_rag_file_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6637,7 +7764,7 @@ def test_get_rag_file_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagFile() + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method sample_request = { @@ -6653,14 +7780,12 @@ def test_get_rag_file_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.get_rag_file(**mock_args) + client.delete_rag_file(**mock_args) # Establish that the underlying call was made with the expected # request object values. @@ -6673,7 +7798,7 @@ def test_get_rag_file_rest_flattened(): ) -def test_get_rag_file_rest_flattened_error(transport: str = "rest"): +def test_delete_rag_file_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6682,13 +7807,13 @@ def test_get_rag_file_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): - client.get_rag_file( - vertex_rag_data_service.GetRagFileRequest(), + client.delete_rag_file( + vertex_rag_data_service.DeleteRagFileRequest(), name="name_value", ) -def test_list_rag_files_rest_use_cached_wrapped_rpc(): +def test_update_rag_engine_config_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -6702,35 +7827,43 @@ def test_list_rag_files_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.list_rag_files in client._transport._wrapped_methods + assert ( + client._transport.update_rag_engine_config + in client._transport._wrapped_methods + ) # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.list_rag_files] = mock_rpc + client._transport._wrapped_methods[ + client._transport.update_rag_engine_config + ] = mock_rpc request = {} - client.list_rag_files(request) + client.update_rag_engine_config(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - client.list_rag_files(request) + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.update_rag_engine_config(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_list_rag_files_rest_required_fields( - request_type=vertex_rag_data_service.ListRagFilesRequest, +def test_update_rag_engine_config_rest_required_fields( + request_type=vertex_rag_data_service.UpdateRagEngineConfigRequest, ): transport_class = transports.VertexRagDataServiceRestTransport request_init = {} - request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) jsonified_request = json.loads( @@ -6741,28 +7874,17 @@ def test_list_rag_files_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_rag_files._get_unset_required_fields(jsonified_request) + ).update_rag_engine_config._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = "parent_value" - unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).list_rag_files._get_unset_required_fields(jsonified_request) - # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set( - ( - "page_size", - "page_token", - ) - ) + ).update_rag_engine_config._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone - assert "parent" in jsonified_request - assert jsonified_request["parent"] == "parent_value" client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -6771,7 +7893,7 @@ def test_list_rag_files_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. 
- return_value = vertex_rag_data_service.ListRagFilesResponse() + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -6783,47 +7905,37 @@ def test_list_rag_files_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "get", + "method": "patch", "query_params": pb_request, } + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.list_rag_files(request) + response = client.update_rag_engine_config(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_list_rag_files_rest_unset_required_fields(): +def test_update_rag_engine_config_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.list_rag_files._get_unset_required_fields({}) - assert set(unset_fields) == ( - set( - ( - "pageSize", - "pageToken", - ) - ) - & set(("parent",)) - ) + unset_fields = transport.update_rag_engine_config._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("ragEngineConfig",))) -def test_list_rag_files_rest_flattened(): +def test_update_rag_engine_config_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -6832,43 +7944,43 @@ def test_list_rag_files_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.ListRagFilesResponse() + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method sample_request = { - "parent": "projects/sample1/locations/sample2/ragCorpora/sample3" + "rag_engine_config": { + "name": "projects/sample1/locations/sample2/ragEngineConfig" + } } # get truthy value for each flattened field mock_args = dict( - parent="parent_value", + rag_engine_config=vertex_rag_data.RagEngineConfig(name="name_value"), ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.list_rag_files(**mock_args) + client.update_rag_engine_config(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{parent=projects/*/locations/*/ragCorpora/*}/ragFiles" + "%s/v1/{rag_engine_config.name=projects/*/locations/*/ragEngineConfig}" % client.transport._host, args[1], ) -def test_list_rag_files_rest_flattened_error(transport: str = "rest"): +def test_update_rag_engine_config_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -6877,78 +7989,13 @@ def test_list_rag_files_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.list_rag_files( - vertex_rag_data_service.ListRagFilesRequest(), - parent="parent_value", - ) - - -def test_list_rag_files_rest_pager(transport: str = "rest"): - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, "request") as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. - # with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - next_page_token="abc", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[], - next_page_token="def", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - ], - next_page_token="ghi", - ), - vertex_rag_data_service.ListRagFilesResponse( - rag_files=[ - vertex_rag_data.RagFile(), - vertex_rag_data.RagFile(), - ], - ), - ) - # Two responses for two calls - response = response + response - - # Wrap the values into proper Response objs - response = tuple( - vertex_rag_data_service.ListRagFilesResponse.to_json(x) for x in response + client.update_rag_engine_config( + vertex_rag_data_service.UpdateRagEngineConfigRequest(), + rag_engine_config=vertex_rag_data.RagEngineConfig(name="name_value"), ) - return_values = tuple(Response() for i in response) - for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode("UTF-8") - return_val.status_code = 200 - req.side_effect = return_values - - sample_request = { - "parent": "projects/sample1/locations/sample2/ragCorpora/sample3" - } - - pager = client.list_rag_files(request=sample_request) - - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, vertex_rag_data.RagFile) for i in results) - - pages = list(client.list_rag_files(request=sample_request).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token -def test_delete_rag_file_rest_use_cached_wrapped_rpc(): +def test_get_rag_engine_config_rest_use_cached_wrapped_rpc(): # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, # instead of constructing them on each call with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: @@ -6962,34 +8009,35 @@ def test_delete_rag_file_rest_use_cached_wrapped_rpc(): wrapper_fn.reset_mock() # Ensure method has been cached - assert client._transport.delete_rag_file in client._transport._wrapped_methods + assert ( + 
client._transport.get_rag_engine_config + in client._transport._wrapped_methods + ) # Replace cached wrapped function with mock mock_rpc = mock.Mock() mock_rpc.return_value.name = ( "foo" # operation_request.operation in compute client(s) expect a string. ) - client._transport._wrapped_methods[client._transport.delete_rag_file] = mock_rpc + client._transport._wrapped_methods[ + client._transport.get_rag_engine_config + ] = mock_rpc request = {} - client.delete_rag_file(request) + client.get_rag_engine_config(request) # Establish that the underlying gRPC stub method was called. assert mock_rpc.call_count == 1 - # Operation methods build a cached wrapper on first rpc call - # subsequent calls should use the cached wrapper - wrapper_fn.reset_mock() - - client.delete_rag_file(request) + client.get_rag_engine_config(request) # Establish that a new wrapper was not created for this call assert wrapper_fn.call_count == 0 assert mock_rpc.call_count == 2 -def test_delete_rag_file_rest_required_fields( - request_type=vertex_rag_data_service.DeleteRagFileRequest, +def test_get_rag_engine_config_rest_required_fields( + request_type=vertex_rag_data_service.GetRagEngineConfigRequest, ): transport_class = transports.VertexRagDataServiceRestTransport @@ -7005,7 +8053,7 @@ def test_delete_rag_file_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_rag_file._get_unset_required_fields(jsonified_request) + ).get_rag_engine_config._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present @@ -7014,7 +8062,7 @@ def test_delete_rag_file_rest_required_fields( unset_fields = transport_class( credentials=ga_credentials.AnonymousCredentials() - ).delete_rag_file._get_unset_required_fields(jsonified_request) + ).get_rag_engine_config._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone @@ -7028,7 +8076,7 @@ def test_delete_rag_file_rest_required_fields( request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data.RagEngineConfig() # Mock the http request call within the method and fake a response. 
with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values @@ -7040,36 +8088,39 @@ def test_delete_rag_file_rest_required_fields( pb_request = request_type.pb(request) transcode_result = { "uri": "v1/sample_method", - "method": "delete", + "method": "get", "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data.RagEngineConfig.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.delete_rag_file(request) + response = client.get_rag_engine_config(request) expected_params = [("$alt", "json;enum-encoding=int")] actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params -def test_delete_rag_file_rest_unset_required_fields(): +def test_get_rag_engine_config_rest_unset_required_fields(): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials ) - unset_fields = transport.delete_rag_file._get_unset_required_fields({}) + unset_fields = transport.get_rag_engine_config._get_unset_required_fields({}) assert set(unset_fields) == (set(()) & set(("name",))) -def test_delete_rag_file_rest_flattened(): +def test_get_rag_engine_config_rest_flattened(): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest", @@ -7078,12 +8129,10 @@ def test_delete_rag_file_rest_flattened(): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data.RagEngineConfig() # get arguments that satisfy an http rule for this method - sample_request = { - "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" - } + sample_request = {"name": "projects/sample1/locations/sample2/ragEngineConfig"} # get truthy value for each flattened field mock_args = dict( @@ -7094,25 +8143,27 @@ def test_delete_rag_file_rest_flattened(): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 + # Convert return value to protobuf type + return_value = vertex_rag_data.RagEngineConfig.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.delete_rag_file(**mock_args) + client.get_rag_engine_config(**mock_args) # Establish that the underlying call was made with the expected # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1/{name=projects/*/locations/*/ragCorpora/*/ragFiles/*}" + "%s/v1/{name=projects/*/locations/*/ragEngineConfig}" % client.transport._host, args[1], ) -def test_delete_rag_file_rest_flattened_error(transport: str = "rest"): +def test_get_rag_engine_config_rest_flattened_error(transport: str = "rest"): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7121,8 +8172,8 @@ def test_delete_rag_file_rest_flattened_error(transport: str = "rest"): # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): - client.delete_rag_file( - vertex_rag_data_service.DeleteRagFileRequest(), + client.get_rag_engine_config( + vertex_rag_data_service.GetRagEngineConfigRequest(), name="name_value", ) @@ -7444,7 +8495,53 @@ def test_delete_rag_file_empty_call_grpc(): # Establish that the underlying stub method was called. call.assert_called() _, args, _ = call.mock_calls[0] - request_msg = vertex_rag_data_service.DeleteRagFileRequest() + request_msg = vertex_rag_data_service.DeleteRagFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_rag_engine_config_empty_call_grpc(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.update_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.UpdateRagEngineConfigRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_rag_engine_config_empty_call_grpc(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + call.return_value = vertex_rag_data.RagEngineConfig() + client.get_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.GetRagEngineConfigRequest() assert args[0] == request_msg @@ -7721,31 +8818,315 @@ async def test_delete_rag_file_empty_call_grpc_asyncio(): call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") ) - await client.delete_rag_file(request=None) + await client.delete_rag_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.DeleteRagFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. 
+@pytest.mark.asyncio +async def test_update_rag_engine_config_empty_call_grpc_asyncio(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.update_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.UpdateRagEngineConfigRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_rag_engine_config_empty_call_grpc_asyncio(): + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vertex_rag_data.RagEngineConfig( + name="name_value", + ) + ) + await client.get_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.GetRagEngineConfigRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = VertexRagDataServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_create_rag_corpus_rest_bad_request( + request_type=vertex_rag_data_service.CreateRagCorpusRequest, +): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_rag_corpus(request) + + +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.CreateRagCorpusRequest, + dict, + ], +) +def test_create_rag_corpus_rest_call_success(request_type): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["rag_corpus"] = { + "vector_db_config": { + "rag_managed_db": { + "knn": {}, + "ann": {"tree_depth": 1060, "leaf_count": 1056}, + }, + "pinecone": {"index_name": "index_name_value"}, + "vertex_vector_search": { + "index_endpoint": "index_endpoint_value", + "index": "index_value", + }, + "api_auth": { + "api_key_config": { + "api_key_secret_version": "api_key_secret_version_value" + } + }, + "rag_embedding_model_config": { + "vertex_prediction_endpoint": { + "endpoint": "endpoint_value", + "model": "model_value", + "model_version_id": "model_version_id_value", + } + }, + }, + "vertex_ai_search_config": {"serving_config": "serving_config_value"}, + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "corpus_status": {"state": 1, "error_status": "error_status_value"}, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[ + "rag_corpus" + ] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["rag_corpus"][field])): + del request_init["rag_corpus"][field][i][subfield] + else: + del request_init["rag_corpus"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_rag_corpus(request) + + # Establish that the response is the type that we expect. 
+ json_return_value = json_format.MessageToJson(return_value) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_rag_corpus_rest_interceptors(null_interceptor): + transport = transports.VertexRagDataServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.VertexRagDataServiceRestInterceptor(), + ) + client = VertexRagDataServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "post_create_rag_corpus" + ) as post, mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, + "post_create_rag_corpus_with_metadata", + ) as post_with_metadata, mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "pre_create_rag_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb( + vertex_rag_data_service.CreateRagCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } - # Establish that the underlying stub method was called. - call.assert_called() - _, args, _ = call.mock_calls[0] - request_msg = vertex_rag_data_service.DeleteRagFileRequest() + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = json_format.MessageToJson(operations_pb2.Operation()) + req.return_value.content = return_value - assert args[0] == request_msg + request = vertex_rag_data_service.CreateRagCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata + client.create_rag_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) -def test_transport_kind_rest(): - transport = VertexRagDataServiceClient.get_transport_class("rest")( - credentials=ga_credentials.AnonymousCredentials() - ) - assert transport.kind == "rest" + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() -def test_create_rag_corpus_rest_bad_request( - request_type=vertex_rag_data_service.CreateRagCorpusRequest, +def test_update_rag_corpus_rest_bad_request( + request_type=vertex_rag_data_service.UpdateRagCorpusRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
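The call-success tests above guard against version skew between the protobuf definitions used at generation time and the ones installed at test time: any nested field in the sample request that the runtime message no longer declares is deleted before the request object is built. A rough, self-contained sketch of that pruning step, using plain dicts and a hard-coded allow-list in place of the real introspection over meta.fields / DESCRIPTOR.fields (all names below are illustrative, not part of the generated tests):

# Pairs of (field, subfield) the installed dependency is assumed to still define.
runtime_nested_fields = {
    ("vector_db_config", "pinecone"),
    ("vector_db_config", "vertex_vector_search"),
}

request_init = {
    "vector_db_config": {
        "pinecone": {"index_name": "index_name_value"},
        "dropped_upstream": {"unused": True},  # pretend this subfield no longer exists at runtime
    },
    "display_name": "display_name_value",
}

def prune_unknown_subfields(payload, known_pairs):
    # Mirror of the subfields_not_in_runtime loop: walk one level of nesting and
    # drop any (field, subfield) pair the runtime dependency does not define.
    for field, value in payload.items():
        if isinstance(value, dict):
            for subfield in list(value):
                if (field, subfield) not in known_pairs:
                    del value[subfield]

prune_unknown_subfields(request_init, runtime_nested_fields)
assert "dropped_upstream" not in request_init["vector_db_config"]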
@@ -7760,23 +9141,25 @@ def test_create_rag_corpus_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.create_rag_corpus(request) + client.update_rag_corpus(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.CreateRagCorpusRequest, + vertex_rag_data_service.UpdateRagCorpusRequest, dict, ], ) -def test_create_rag_corpus_rest_call_success(request_type): +def test_update_rag_corpus_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + } request_init["rag_corpus"] = { "vector_db_config": { "rag_managed_db": { @@ -7802,7 +9185,7 @@ def test_create_rag_corpus_rest_call_success(request_type): }, }, "vertex_ai_search_config": {"serving_config": "serving_config_value"}, - "name": "name_value", + "name": "projects/sample1/locations/sample2/ragCorpora/sample3", "display_name": "display_name_value", "description": "description_value", "create_time": {"seconds": 751, "nanos": 543}, @@ -7815,7 +9198,7 @@ def test_create_rag_corpus_rest_call_success(request_type): # See https://github.com/googleapis/gapic-generator-python/issues/1748 # Determine if the message type is proto-plus or protobuf - test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[ + test_field = vertex_rag_data_service.UpdateRagCorpusRequest.meta.fields[ "rag_corpus" ] @@ -7893,14 +9276,14 @@ def get_message_fields(field): response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.create_rag_corpus(request) + response = client.update_rag_corpus(request) # Establish that the response is the type that we expect. 
json_return_value = json_format.MessageToJson(return_value) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_create_rag_corpus_rest_interceptors(null_interceptor): +def test_update_rag_corpus_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -7916,18 +9299,18 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_create_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "post_update_rag_corpus" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_create_rag_corpus_with_metadata", + "post_update_rag_corpus_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_create_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "pre_update_rag_corpus" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb( - vertex_rag_data_service.CreateRagCorpusRequest() + pb_message = vertex_rag_data_service.UpdateRagCorpusRequest.pb( + vertex_rag_data_service.UpdateRagCorpusRequest() ) transcode.return_value = { "method": "post", @@ -7942,7 +9325,7 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.content = return_value - request = vertex_rag_data_service.CreateRagCorpusRequest() + request = vertex_rag_data_service.UpdateRagCorpusRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -7951,7 +9334,7 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): post.return_value = operations_pb2.Operation() post_with_metadata.return_value = operations_pb2.Operation(), metadata - client.create_rag_corpus( + client.update_rag_corpus( request, metadata=[ ("key", "val"), @@ -7964,16 +9347,14 @@ def test_create_rag_corpus_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_update_rag_corpus_rest_bad_request( - request_type=vertex_rag_data_service.UpdateRagCorpusRequest, +def test_get_rag_corpus_rest_bad_request( + request_type=vertex_rag_data_service.GetRagCorpusRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = { - "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} - } + request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -7988,149 +9369,183 @@ def test_update_rag_corpus_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.update_rag_corpus(request) + client.get_rag_corpus(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.UpdateRagCorpusRequest, + vertex_rag_data_service.GetRagCorpusRequest, dict, ], ) -def test_update_rag_corpus_rest_call_success(request_type): +def test_get_rag_corpus_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = { - "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} - } - request_init["rag_corpus"] = { - "vector_db_config": { - "rag_managed_db": { - "knn": {}, - "ann": {"tree_depth": 1060, "leaf_count": 1056}, - }, - "pinecone": {"index_name": "index_name_value"}, - "vertex_vector_search": { - "index_endpoint": "index_endpoint_value", - "index": "index_value", - }, - "api_auth": { - "api_key_config": { - "api_key_secret_version": "api_key_secret_version_value" - } - }, - "rag_embedding_model_config": { - "vertex_prediction_endpoint": { - "endpoint": "endpoint_value", - "model": "model_value", - "model_version_id": "model_version_id_value", - } - }, - }, - "vertex_ai_search_config": {"serving_config": "serving_config_value"}, - "name": "projects/sample1/locations/sample2/ragCorpora/sample3", - "display_name": "display_name_value", - "description": "description_value", - "create_time": {"seconds": 751, "nanos": 543}, - "update_time": {}, - "corpus_status": {"state": 1, "error_status": "error_status_value"}, - "encryption_spec": {"kms_key_name": "kms_key_name_value"}, - } - # The version of a generated dependency at test runtime may differ from the version used during generation. - # Delete any fields which are not present in the current runtime dependency - # See https://github.com/googleapis/gapic-generator-python/issues/1748 + request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request = request_type(**request_init) - # Determine if the message type is proto-plus or protobuf - test_field = vertex_rag_data_service.UpdateRagCorpusRequest.meta.fields[ - "rag_corpus" - ] + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = vertex_rag_data.RagCorpus( + name="name_value", + display_name="display_name_value", + description="description_value", + ) - def get_message_fields(field): - # Given a field which is a message (composite type), return a list with - # all the fields of the message. - # If the field is not a composite type, return an empty list. 
- message_fields = [] + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 - if hasattr(field, "message") and field.message: - is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + # Convert return value to protobuf type + return_value = vertex_rag_data.RagCorpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_rag_corpus(request) - if is_field_type_proto_plus_type: - message_fields = field.message.meta.fields.values() - # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types - else: # pragma: NO COVER - message_fields = field.message.DESCRIPTOR.fields - return message_fields + # Establish that the response is the type that we expect. + assert isinstance(response, vertex_rag_data.RagCorpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" - runtime_nested_fields = [ - (field.name, nested_field.name) - for field in get_message_fields(test_field) - for nested_field in get_message_fields(field) - ] - subfields_not_in_runtime = [] +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_rag_corpus_rest_interceptors(null_interceptor): + transport = transports.VertexRagDataServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.VertexRagDataServiceRestInterceptor(), + ) + client = VertexRagDataServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "post_get_rag_corpus" + ) as post, mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, + "post_get_rag_corpus_with_metadata", + ) as post_with_metadata, mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "pre_get_rag_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = vertex_rag_data_service.GetRagCorpusRequest.pb( + vertex_rag_data_service.GetRagCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = vertex_rag_data.RagCorpus.to_json(vertex_rag_data.RagCorpus()) + req.return_value.content = return_value + + request = vertex_rag_data_service.GetRagCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = vertex_rag_data.RagCorpus() + post_with_metadata.return_value = vertex_rag_data.RagCorpus(), metadata + + client.get_rag_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() + + +def test_list_rag_corpora_rest_bad_request( + request_type=vertex_rag_data_service.ListRagCorporaRequest, +): + client = VertexRagDataServiceClient( + 
credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_rag_corpora(request) - # For each item in the sample request, create a list of sub fields which are not present at runtime - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER - result = None - is_repeated = False - # For repeated fields - if isinstance(value, list) and len(value): - is_repeated = True - result = value[0] - # For fields where the type is another message - if isinstance(value, dict): - result = value - if result and hasattr(result, "keys"): - for subfield in result.keys(): - if (field, subfield) not in runtime_nested_fields: - subfields_not_in_runtime.append( - { - "field": field, - "subfield": subfield, - "is_repeated": is_repeated, - } - ) +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.ListRagCorporaRequest, + dict, + ], +) +def test_list_rag_corpora_rest_call_success(request_type): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) - # Remove fields from the sample request which are not present in the runtime version of the dependency - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER - field = subfield_to_delete.get("field") - field_repeated = subfield_to_delete.get("is_repeated") - subfield = subfield_to_delete.get("subfield") - if subfield: - if field_repeated: - for i in range(0, len(request_init["rag_corpus"][field])): - del request_init["rag_corpus"][field][i][subfield] - else: - del request_init["rag_corpus"][field][subfield] + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagCorporaResponse( + next_page_token="next_page_token_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagCorporaResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.update_rag_corpus(request) + response = client.list_rag_corpora(request) # Establish that the response is the type that we expect. - json_return_value = json_format.MessageToJson(return_value) + assert isinstance(response, pagers.ListRagCorporaPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_update_rag_corpus_rest_interceptors(null_interceptor): +def test_list_rag_corpora_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8144,20 +9559,18 @@ def test_update_rag_corpus_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_update_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "post_list_rag_corpora" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_update_rag_corpus_with_metadata", + "post_list_rag_corpora_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_update_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "pre_list_rag_corpora" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.UpdateRagCorpusRequest.pb( - vertex_rag_data_service.UpdateRagCorpusRequest() + pb_message = vertex_rag_data_service.ListRagCorporaRequest.pb( + vertex_rag_data_service.ListRagCorporaRequest() ) transcode.return_value = { "method": "post", @@ -8169,19 +9582,24 @@ def test_update_rag_corpus_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = json_format.MessageToJson(operations_pb2.Operation()) + return_value = vertex_rag_data_service.ListRagCorporaResponse.to_json( + vertex_rag_data_service.ListRagCorporaResponse() + ) req.return_value.content = return_value - request = vertex_rag_data_service.UpdateRagCorpusRequest() + request = vertex_rag_data_service.ListRagCorporaRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() - post_with_metadata.return_value = operations_pb2.Operation(), metadata + post.return_value = vertex_rag_data_service.ListRagCorporaResponse() + post_with_metadata.return_value = ( + vertex_rag_data_service.ListRagCorporaResponse(), + metadata, + ) - client.update_rag_corpus( + client.list_rag_corpora( request, metadata=[ ("key", "val"), @@ -8194,8 +9612,8 @@ def 
test_update_rag_corpus_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_get_rag_corpus_rest_bad_request( - request_type=vertex_rag_data_service.GetRagCorpusRequest, +def test_delete_rag_corpus_rest_bad_request( + request_type=vertex_rag_data_service.DeleteRagCorpusRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" @@ -8216,17 +9634,17 @@ def test_get_rag_corpus_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.get_rag_corpus(request) + client.delete_rag_corpus(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.GetRagCorpusRequest, + vertex_rag_data_service.DeleteRagCorpusRequest, dict, ], ) -def test_get_rag_corpus_rest_call_success(request_type): +def test_delete_rag_corpus_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -8238,33 +9656,23 @@ def test_get_rag_corpus_rest_call_success(request_type): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagCorpus( - name="name_value", - display_name="display_name_value", - description="description_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data.RagCorpus.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.get_rag_corpus(request) + response = client.delete_rag_corpus(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, vertex_rag_data.RagCorpus) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" + json_return_value = json_format.MessageToJson(return_value) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_get_rag_corpus_rest_interceptors(null_interceptor): +def test_delete_rag_corpus_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8278,18 +9686,20 @@ def test_get_rag_corpus_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_get_rag_corpus" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "post_delete_rag_corpus" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_get_rag_corpus_with_metadata", + "post_delete_rag_corpus_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_get_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "pre_delete_rag_corpus" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.GetRagCorpusRequest.pb( - vertex_rag_data_service.GetRagCorpusRequest() + pb_message = vertex_rag_data_service.DeleteRagCorpusRequest.pb( + vertex_rag_data_service.DeleteRagCorpusRequest() ) transcode.return_value = { "method": "post", @@ -8301,19 +9711,19 @@ def test_get_rag_corpus_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data.RagCorpus.to_json(vertex_rag_data.RagCorpus()) + return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.content = return_value - request = vertex_rag_data_service.GetRagCorpusRequest() + request = vertex_rag_data_service.DeleteRagCorpusRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data.RagCorpus() - post_with_metadata.return_value = vertex_rag_data.RagCorpus(), metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata - client.get_rag_corpus( + client.delete_rag_corpus( request, metadata=[ ("key", "val"), @@ -8326,14 +9736,14 @@ def test_get_rag_corpus_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_list_rag_corpora_rest_bad_request( - request_type=vertex_rag_data_service.ListRagCorporaRequest, +def test_upload_rag_file_rest_bad_request( + request_type=vertex_rag_data_service.UploadRagFileRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -8348,51 +9758,48 @@ def test_list_rag_corpora_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.list_rag_corpora(request) + client.upload_rag_file(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ListRagCorporaRequest, + vertex_rag_data_service.UploadRagFileRequest, dict, ], ) -def test_list_rag_corpora_rest_call_success(request_type): +def test_upload_rag_file_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.ListRagCorporaResponse( - next_page_token="next_page_token_value", - ) + return_value = vertex_rag_data_service.UploadRagFileResponse() # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagCorporaResponse.pb(return_value) + return_value = vertex_rag_data_service.UploadRagFileResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.list_rag_corpora(request) + response = client.upload_rag_file(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListRagCorporaPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_list_rag_corpora_rest_interceptors(null_interceptor): +def test_upload_rag_file_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8406,18 +9813,18 @@ def test_list_rag_corpora_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_list_rag_corpora" + transports.VertexRagDataServiceRestInterceptor, "post_upload_rag_file" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_list_rag_corpora_with_metadata", + "post_upload_rag_file_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_list_rag_corpora" + transports.VertexRagDataServiceRestInterceptor, "pre_upload_rag_file" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.ListRagCorporaRequest.pb( - vertex_rag_data_service.ListRagCorporaRequest() + pb_message = vertex_rag_data_service.UploadRagFileRequest.pb( + vertex_rag_data_service.UploadRagFileRequest() ) transcode.return_value = { "method": "post", @@ -8429,24 +9836,24 @@ def test_list_rag_corpora_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data_service.ListRagCorporaResponse.to_json( - vertex_rag_data_service.ListRagCorporaResponse() + return_value = vertex_rag_data_service.UploadRagFileResponse.to_json( + vertex_rag_data_service.UploadRagFileResponse() ) req.return_value.content = return_value - request = vertex_rag_data_service.ListRagCorporaRequest() + request = vertex_rag_data_service.UploadRagFileRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data_service.ListRagCorporaResponse() + post.return_value = vertex_rag_data_service.UploadRagFileResponse() post_with_metadata.return_value = ( - vertex_rag_data_service.ListRagCorporaResponse(), + vertex_rag_data_service.UploadRagFileResponse(), metadata, ) - client.list_rag_corpora( + client.upload_rag_file( request, metadata=[ ("key", "val"), @@ -8459,14 +9866,14 @@ def test_list_rag_corpora_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_delete_rag_corpus_rest_bad_request( - request_type=vertex_rag_data_service.DeleteRagCorpusRequest, +def test_import_rag_files_rest_bad_request( + request_type=vertex_rag_data_service.ImportRagFilesRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -8481,23 +9888,23 @@ def test_delete_rag_corpus_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.delete_rag_corpus(request) + client.import_rag_files(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.DeleteRagCorpusRequest, + vertex_rag_data_service.ImportRagFilesRequest, dict, ], ) -def test_delete_rag_corpus_rest_call_success(request_type): +def test_import_rag_files_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -8512,14 +9919,14 @@ def test_delete_rag_corpus_rest_call_success(request_type): response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.delete_rag_corpus(request) + response = client.import_rag_files(request) # Establish that the response is the type that we expect. json_return_value = json_format.MessageToJson(return_value) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_delete_rag_corpus_rest_interceptors(null_interceptor): +def test_import_rag_files_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8535,18 +9942,18 @@ def test_delete_rag_corpus_rest_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_delete_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "post_import_rag_files" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_delete_rag_corpus_with_metadata", + "post_import_rag_files_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_delete_rag_corpus" + transports.VertexRagDataServiceRestInterceptor, "pre_import_rag_files" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.DeleteRagCorpusRequest.pb( - vertex_rag_data_service.DeleteRagCorpusRequest() + pb_message = vertex_rag_data_service.ImportRagFilesRequest.pb( + vertex_rag_data_service.ImportRagFilesRequest() ) transcode.return_value = { "method": "post", @@ -8561,7 +9968,7 @@ def test_delete_rag_corpus_rest_interceptors(null_interceptor): return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.content = return_value - request = vertex_rag_data_service.DeleteRagCorpusRequest() + request = vertex_rag_data_service.ImportRagFilesRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -8570,7 +9977,7 @@ def test_delete_rag_corpus_rest_interceptors(null_interceptor): post.return_value = operations_pb2.Operation() post_with_metadata.return_value = operations_pb2.Operation(), metadata - client.delete_rag_corpus( + client.import_rag_files( request, metadata=[ ("key", "val"), @@ -8583,14 +9990,16 @@ def 
test_delete_rag_corpus_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_upload_rag_file_rest_bad_request( - request_type=vertex_rag_data_service.UploadRagFileRequest, +def test_get_rag_file_rest_bad_request( + request_type=vertex_rag_data_service.GetRagFileRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -8605,48 +10014,57 @@ def test_upload_rag_file_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.upload_rag_file(request) + client.get_rag_file(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.UploadRagFileRequest, + vertex_rag_data_service.GetRagFileRequest, dict, ], ) -def test_upload_rag_file_rest_call_success(request_type): +def test_get_rag_file_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.UploadRagFileResponse() + return_value = vertex_rag_data.RagFile( + name="name_value", + display_name="display_name_value", + description="description_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.UploadRagFileResponse.pb(return_value) + return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.upload_rag_file(request) + response = client.get_rag_file(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) + assert isinstance(response, vertex_rag_data.RagFile) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_upload_rag_file_rest_interceptors(null_interceptor): +def test_get_rag_file_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8660,18 +10078,18 @@ def test_upload_rag_file_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_upload_rag_file" + transports.VertexRagDataServiceRestInterceptor, "post_get_rag_file" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_upload_rag_file_with_metadata", + "post_get_rag_file_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_upload_rag_file" + transports.VertexRagDataServiceRestInterceptor, "pre_get_rag_file" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.UploadRagFileRequest.pb( - vertex_rag_data_service.UploadRagFileRequest() + pb_message = vertex_rag_data_service.GetRagFileRequest.pb( + vertex_rag_data_service.GetRagFileRequest() ) transcode.return_value = { "method": "post", @@ -8683,24 +10101,19 @@ def test_upload_rag_file_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data_service.UploadRagFileResponse.to_json( - vertex_rag_data_service.UploadRagFileResponse() - ) + return_value = vertex_rag_data.RagFile.to_json(vertex_rag_data.RagFile()) req.return_value.content = return_value - request = vertex_rag_data_service.UploadRagFileRequest() + request = vertex_rag_data_service.GetRagFileRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data_service.UploadRagFileResponse() - post_with_metadata.return_value = ( - vertex_rag_data_service.UploadRagFileResponse(), - metadata, - ) + post.return_value = vertex_rag_data.RagFile() + post_with_metadata.return_value = vertex_rag_data.RagFile(), metadata - client.upload_rag_file( + client.get_rag_file( request, metadata=[ ("key", "val"), @@ -8713,8 +10126,8 @@ def test_upload_rag_file_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_import_rag_files_rest_bad_request( - request_type=vertex_rag_data_service.ImportRagFilesRequest, +def test_list_rag_files_rest_bad_request( + request_type=vertex_rag_data_service.ListRagFilesRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" @@ -8735,17 +10148,17 @@ def test_import_rag_files_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.import_rag_files(request) + client.list_rag_files(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ImportRagFilesRequest, + vertex_rag_data_service.ListRagFilesRequest, dict, ], ) -def 
test_import_rag_files_rest_call_success(request_type): +def test_list_rag_files_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -8757,23 +10170,29 @@ def test_import_rag_files_rest_call_success(request_type): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagFilesResponse( + next_page_token="next_page_token_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.import_rag_files(request) + response = client.list_rag_files(request) # Establish that the response is the type that we expect. - json_return_value = json_format.MessageToJson(return_value) + assert isinstance(response, pagers.ListRagFilesPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_import_rag_files_rest_interceptors(null_interceptor): +def test_list_rag_files_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8787,20 +10206,18 @@ def test_import_rag_files_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_import_rag_files" + transports.VertexRagDataServiceRestInterceptor, "post_list_rag_files" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_import_rag_files_with_metadata", + "post_list_rag_files_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_import_rag_files" + transports.VertexRagDataServiceRestInterceptor, "pre_list_rag_files" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.ImportRagFilesRequest.pb( - vertex_rag_data_service.ImportRagFilesRequest() + pb_message = vertex_rag_data_service.ListRagFilesRequest.pb( + vertex_rag_data_service.ListRagFilesRequest() ) transcode.return_value = { "method": "post", @@ -8812,19 +10229,24 @@ def test_import_rag_files_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = json_format.MessageToJson(operations_pb2.Operation()) + return_value = vertex_rag_data_service.ListRagFilesResponse.to_json( + vertex_rag_data_service.ListRagFilesResponse() + ) req.return_value.content = return_value - request = vertex_rag_data_service.ImportRagFilesRequest() + request = vertex_rag_data_service.ListRagFilesRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = 
request, metadata - post.return_value = operations_pb2.Operation() - post_with_metadata.return_value = operations_pb2.Operation(), metadata + post.return_value = vertex_rag_data_service.ListRagFilesResponse() + post_with_metadata.return_value = ( + vertex_rag_data_service.ListRagFilesResponse(), + metadata, + ) - client.import_rag_files( + client.list_rag_files( request, metadata=[ ("key", "val"), @@ -8837,8 +10259,8 @@ def test_import_rag_files_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_get_rag_file_rest_bad_request( - request_type=vertex_rag_data_service.GetRagFileRequest, +def test_delete_rag_file_rest_bad_request( + request_type=vertex_rag_data_service.DeleteRagFileRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" @@ -8861,17 +10283,17 @@ def test_get_rag_file_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.get_rag_file(request) + client.delete_rag_file(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.GetRagFileRequest, + vertex_rag_data_service.DeleteRagFileRequest, dict, ], ) -def test_get_rag_file_rest_call_success(request_type): +def test_delete_rag_file_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -8885,33 +10307,23 @@ def test_get_rag_file_rest_call_success(request_type): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagFile( - name="name_value", - display_name="display_name_value", - description="description_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.get_rag_file(request) + response = client.delete_rag_file(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, vertex_rag_data.RagFile) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" + json_return_value = json_format.MessageToJson(return_value) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_get_rag_file_rest_interceptors(null_interceptor): +def test_delete_rag_file_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -8925,18 +10337,20 @@ def test_get_rag_file_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_get_rag_file" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "post_delete_rag_file" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_get_rag_file_with_metadata", + "post_delete_rag_file_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_get_rag_file" + transports.VertexRagDataServiceRestInterceptor, "pre_delete_rag_file" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.GetRagFileRequest.pb( - vertex_rag_data_service.GetRagFileRequest() + pb_message = vertex_rag_data_service.DeleteRagFileRequest.pb( + vertex_rag_data_service.DeleteRagFileRequest() ) transcode.return_value = { "method": "post", @@ -8948,19 +10362,19 @@ def test_get_rag_file_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data.RagFile.to_json(vertex_rag_data.RagFile()) + return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.content = return_value - request = vertex_rag_data_service.GetRagFileRequest() + request = vertex_rag_data_service.DeleteRagFileRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data.RagFile() - post_with_metadata.return_value = vertex_rag_data.RagFile(), metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata - client.get_rag_file( + client.delete_rag_file( request, metadata=[ ("key", "val"), @@ -8973,14 +10387,18 @@ def test_get_rag_file_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_list_rag_files_rest_bad_request( - request_type=vertex_rag_data_service.ListRagFilesRequest, +def test_update_rag_engine_config_rest_bad_request( + request_type=vertex_rag_data_service.UpdateRagEngineConfigRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "rag_engine_config": { + "name": "projects/sample1/locations/sample2/ragEngineConfig" + } + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -8995,51 +10413,122 @@ def test_list_rag_files_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.list_rag_files(request) + client.update_rag_engine_config(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ListRagFilesRequest, + vertex_rag_data_service.UpdateRagEngineConfigRequest, dict, ], ) -def test_list_rag_files_rest_call_success(request_type): +def test_update_rag_engine_config_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "rag_engine_config": { + "name": "projects/sample1/locations/sample2/ragEngineConfig" + } + } + request_init["rag_engine_config"] = { + "name": "projects/sample1/locations/sample2/ragEngineConfig", + "rag_managed_db_config": {"scaled": {}, "basic": {}, "unprovisioned": {}}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = vertex_rag_data_service.UpdateRagEngineConfigRequest.meta.fields[ + "rag_engine_config" + ] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["rag_engine_config"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = 
subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["rag_engine_config"][field])): + del request_init["rag_engine_config"][field][i][subfield] + else: + del request_init["rag_engine_config"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.ListRagFilesResponse( - next_page_token="next_page_token_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.list_rag_files(request) + response = client.update_rag_engine_config(request) # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListRagFilesPager) - assert response.next_page_token == "next_page_token_value" + json_return_value = json_format.MessageToJson(return_value) @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_list_rag_files_rest_interceptors(null_interceptor): +def test_update_rag_engine_config_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -9053,18 +10542,20 @@ def test_list_rag_files_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_list_rag_files" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.VertexRagDataServiceRestInterceptor, "post_update_rag_engine_config" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_list_rag_files_with_metadata", + "post_update_rag_engine_config_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_list_rag_files" + transports.VertexRagDataServiceRestInterceptor, "pre_update_rag_engine_config" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.ListRagFilesRequest.pb( - vertex_rag_data_service.ListRagFilesRequest() + pb_message = vertex_rag_data_service.UpdateRagEngineConfigRequest.pb( + vertex_rag_data_service.UpdateRagEngineConfigRequest() ) transcode.return_value = { "method": "post", @@ -9076,24 +10567,19 @@ def test_list_rag_files_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data_service.ListRagFilesResponse.to_json( - vertex_rag_data_service.ListRagFilesResponse() - ) + return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.content = return_value - request = vertex_rag_data_service.ListRagFilesRequest() + request = vertex_rag_data_service.UpdateRagEngineConfigRequest() metadata = [ ("key", 
"val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data_service.ListRagFilesResponse() - post_with_metadata.return_value = ( - vertex_rag_data_service.ListRagFilesResponse(), - metadata, - ) + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata - client.list_rag_files( + client.update_rag_engine_config( request, metadata=[ ("key", "val"), @@ -9106,16 +10592,14 @@ def test_list_rag_files_rest_interceptors(null_interceptor): post_with_metadata.assert_called_once() -def test_delete_rag_file_rest_bad_request( - request_type=vertex_rag_data_service.DeleteRagFileRequest, +def test_get_rag_engine_config_rest_bad_request( + request_type=vertex_rag_data_service.GetRagEngineConfigRequest, ): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" - } + request_init = {"name": "projects/sample1/locations/sample2/ragEngineConfig"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -9130,47 +10614,51 @@ def test_delete_rag_file_rest_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - client.delete_rag_file(request) + client.get_rag_engine_config(request) @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.DeleteRagFileRequest, + vertex_rag_data_service.GetRagEngineConfigRequest, dict, ], ) -def test_delete_rag_file_rest_call_success(request_type): +def test_get_rag_engine_config_rest_call_success(request_type): client = VertexRagDataServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" - } + request_init = {"name": "projects/sample1/locations/sample2/ragEngineConfig"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data.RagEngineConfig( + name="name_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data.RagEngineConfig.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.content = json_return_value.encode("UTF-8") req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = client.delete_rag_file(request) + response = client.get_rag_engine_config(request) # Establish that the response is the type that we expect. 
- json_return_value = json_format.MessageToJson(return_value) + assert isinstance(response, vertex_rag_data.RagEngineConfig) + assert response.name == "name_value" @pytest.mark.parametrize("null_interceptor", [True, False]) -def test_delete_rag_file_rest_interceptors(null_interceptor): +def test_get_rag_engine_config_rest_interceptors(null_interceptor): transport = transports.VertexRagDataServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), interceptor=None @@ -9184,20 +10672,18 @@ def test_delete_rag_file_rest_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "post_delete_rag_file" + transports.VertexRagDataServiceRestInterceptor, "post_get_rag_engine_config" ) as post, mock.patch.object( transports.VertexRagDataServiceRestInterceptor, - "post_delete_rag_file_with_metadata", + "post_get_rag_engine_config_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.VertexRagDataServiceRestInterceptor, "pre_delete_rag_file" + transports.VertexRagDataServiceRestInterceptor, "pre_get_rag_engine_config" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.DeleteRagFileRequest.pb( - vertex_rag_data_service.DeleteRagFileRequest() + pb_message = vertex_rag_data_service.GetRagEngineConfigRequest.pb( + vertex_rag_data_service.GetRagEngineConfigRequest() ) transcode.return_value = { "method": "post", @@ -9209,19 +10695,21 @@ def test_delete_rag_file_rest_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = json_format.MessageToJson(operations_pb2.Operation()) + return_value = vertex_rag_data.RagEngineConfig.to_json( + vertex_rag_data.RagEngineConfig() + ) req.return_value.content = return_value - request = vertex_rag_data_service.DeleteRagFileRequest() + request = vertex_rag_data_service.GetRagEngineConfigRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() - post_with_metadata.return_value = operations_pb2.Operation(), metadata + post.return_value = vertex_rag_data.RagEngineConfig() + post_with_metadata.return_value = vertex_rag_data.RagEngineConfig(), metadata - client.delete_rag_file( + client.get_rag_engine_config( request, metadata=[ ("key", "val"), @@ -10069,40 +11557,326 @@ def test_delete_rag_file_empty_call_rest(): _, args, _ = call.mock_calls[0] request_msg = vertex_rag_data_service.DeleteRagFileRequest() - assert args[0] == request_msg + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_rag_engine_config_empty_call_rest(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + client.update_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. 
+ call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.UpdateRagEngineConfigRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_rag_engine_config_empty_call_rest(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + client.get_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.GetRagEngineConfigRequest() + + assert args[0] == request_msg + + +def test_vertex_rag_data_service_rest_lro_client(): + client = VertexRagDataServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + transport = client.transport + + # Ensure that we have an api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_transport_kind_rest_asyncio(): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." + ) + transport = VertexRagDataServiceAsyncClient.get_transport_class("rest_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "rest_asyncio" + + +@pytest.mark.asyncio +async def test_create_rag_corpus_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.CreateRagCorpusRequest, +): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." + ) + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="rest_asyncio" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(AsyncAuthorizedSession, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.read = mock.AsyncMock(return_value=b"{}") + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + await client.create_rag_corpus(request) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.CreateRagCorpusRequest, + dict, + ], +) +async def test_create_rag_corpus_rest_asyncio_call_success(request_type): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." 
+ ) + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="rest_asyncio" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/locations/sample2"} + request_init["rag_corpus"] = { + "vector_db_config": { + "rag_managed_db": { + "knn": {}, + "ann": {"tree_depth": 1060, "leaf_count": 1056}, + }, + "pinecone": {"index_name": "index_name_value"}, + "vertex_vector_search": { + "index_endpoint": "index_endpoint_value", + "index": "index_value", + }, + "api_auth": { + "api_key_config": { + "api_key_secret_version": "api_key_secret_version_value" + } + }, + "rag_embedding_model_config": { + "vertex_prediction_endpoint": { + "endpoint": "endpoint_value", + "model": "model_value", + "model_version_id": "model_version_id_value", + } + }, + }, + "vertex_ai_search_config": {"serving_config": "serving_config_value"}, + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "corpus_status": {"state": 1, "error_status": "error_status_value"}, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[ + "rag_corpus" + ] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["rag_corpus"][field])): + del request_init["rag_corpus"][field][i][subfield] + else: + del request_init["rag_corpus"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.read = mock.AsyncMock( + return_value=json_return_value.encode("UTF-8") + ) + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = await client.create_rag_corpus(request) + + # Establish that the response is the type that we expect. + json_return_value = json_format.MessageToJson(return_value) -def test_vertex_rag_data_service_rest_lro_client(): - client = VertexRagDataServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", +@pytest.mark.asyncio +@pytest.mark.parametrize("null_interceptor", [True, False]) +async def test_create_rag_corpus_rest_asyncio_interceptors(null_interceptor): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." 
+ ) + transport = transports.AsyncVertexRagDataServiceRestTransport( + credentials=async_anonymous_credentials(), + interceptor=None + if null_interceptor + else transports.AsyncVertexRagDataServiceRestInterceptor(), ) - transport = client.transport + client = VertexRagDataServiceAsyncClient(transport=transport) - # Ensure that we have an api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.AbstractOperationsClient, - ) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, "post_create_rag_corpus" + ) as post, mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, + "post_create_rag_corpus_with_metadata", + ) as post_with_metadata, mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_create_rag_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb( + vertex_rag_data_service.CreateRagCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = json_format.MessageToJson(operations_pb2.Operation()) + req.return_value.read = mock.AsyncMock(return_value=return_value) + request = vertex_rag_data_service.CreateRagCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata -def test_transport_kind_rest_asyncio(): - if not HAS_ASYNC_REST_EXTRA: - pytest.skip( - "the library must be installed with the `async_rest` extra to test this feature." + await client.create_rag_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], ) - transport = VertexRagDataServiceAsyncClient.get_transport_class("rest_asyncio")( - credentials=async_anonymous_credentials() - ) - assert transport.kind == "rest_asyncio" + + pre.assert_called_once() + post.assert_called_once() + post_with_metadata.assert_called_once() @pytest.mark.asyncio -async def test_create_rag_corpus_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.CreateRagCorpusRequest, +async def test_update_rag_corpus_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.UpdateRagCorpusRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -10112,7 +11886,9 @@ async def test_create_rag_corpus_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
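# Aside: a minimal sketch of the hook contract that the interceptor tests above
# exercise. In real code you would subclass the generated
# transports.AsyncVertexRagDataServiceRestInterceptor and pass it to the transport;
# LoggingRagInterceptor below is a hypothetical stand-in that only shows the shapes
# of the pre_*/post_*/post_*_with_metadata hooks being patched in these tests.
class LoggingRagInterceptor:
    def pre_create_rag_corpus(self, request, metadata):
        # Runs before transcoding and sending; must return the (request, metadata) pair.
        print("sending CreateRagCorpus for parent:", getattr(request, "parent", "<unset>"))
        return request, metadata

    def post_create_rag_corpus(self, response):
        # Runs after the response (a long-running Operation here) is deserialized.
        return response

    def post_create_rag_corpus_with_metadata(self, response, metadata):
        # Newer hook variant that may also rewrite the response metadata.
        return response, metadata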
@@ -10126,18 +11902,18 @@ async def test_create_rag_corpus_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.create_rag_corpus(request) + await client.update_rag_corpus(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.CreateRagCorpusRequest, + vertex_rag_data_service.UpdateRagCorpusRequest, dict, ], ) -async def test_create_rag_corpus_rest_asyncio_call_success(request_type): +async def test_update_rag_corpus_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -10147,7 +11923,9 @@ async def test_create_rag_corpus_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + } request_init["rag_corpus"] = { "vector_db_config": { "rag_managed_db": { @@ -10173,7 +11951,7 @@ async def test_create_rag_corpus_rest_asyncio_call_success(request_type): }, }, "vertex_ai_search_config": {"serving_config": "serving_config_value"}, - "name": "name_value", + "name": "projects/sample1/locations/sample2/ragCorpora/sample3", "display_name": "display_name_value", "description": "description_value", "create_time": {"seconds": 751, "nanos": 543}, @@ -10186,7 +11964,7 @@ async def test_create_rag_corpus_rest_asyncio_call_success(request_type): # See https://github.com/googleapis/gapic-generator-python/issues/1748 # Determine if the message type is proto-plus or protobuf - test_field = vertex_rag_data_service.CreateRagCorpusRequest.meta.fields[ + test_field = vertex_rag_data_service.UpdateRagCorpusRequest.meta.fields[ "rag_corpus" ] @@ -10214,67 +11992,217 @@ def get_message_fields(field): subfields_not_in_runtime = [] - # For each item in the sample request, create a list of sub fields which are not present at runtime - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER - result = None - is_repeated = False - # For repeated fields - if isinstance(value, list) and len(value): - is_repeated = True - result = value[0] - # For fields where the type is another message - if isinstance(value, dict): - result = value + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for 
subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["rag_corpus"][field])): + del request_init["rag_corpus"][field][i][subfield] + else: + del request_init["rag_corpus"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.read = mock.AsyncMock( + return_value=json_return_value.encode("UTF-8") + ) + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = await client.update_rag_corpus(request) + + # Establish that the response is the type that we expect. + json_return_value = json_format.MessageToJson(return_value) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("null_interceptor", [True, False]) +async def test_update_rag_corpus_rest_asyncio_interceptors(null_interceptor): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." + ) + transport = transports.AsyncVertexRagDataServiceRestTransport( + credentials=async_anonymous_credentials(), + interceptor=None + if null_interceptor + else transports.AsyncVertexRagDataServiceRestInterceptor(), + ) + client = VertexRagDataServiceAsyncClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, "post_update_rag_corpus" + ) as post, mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, + "post_update_rag_corpus_with_metadata", + ) as post_with_metadata, mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_update_rag_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + post_with_metadata.assert_not_called() + pb_message = vertex_rag_data_service.UpdateRagCorpusRequest.pb( + vertex_rag_data_service.UpdateRagCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = json_format.MessageToJson(operations_pb2.Operation()) + req.return_value.read = mock.AsyncMock(return_value=return_value) + + request = vertex_rag_data_service.UpdateRagCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata + + await client.update_rag_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + 
post.assert_called_once() + post_with_metadata.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_rag_corpus_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.GetRagCorpusRequest, +): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." + ) + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="rest_asyncio" + ) + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(AsyncAuthorizedSession, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.read = mock.AsyncMock(return_value=b"{}") + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + await client.get_rag_corpus(request) + - if result and hasattr(result, "keys"): - for subfield in result.keys(): - if (field, subfield) not in runtime_nested_fields: - subfields_not_in_runtime.append( - { - "field": field, - "subfield": subfield, - "is_repeated": is_repeated, - } - ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "request_type", + [ + vertex_rag_data_service.GetRagCorpusRequest, + dict, + ], +) +async def test_get_rag_corpus_rest_asyncio_call_success(request_type): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." + ) + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="rest_asyncio" + ) - # Remove fields from the sample request which are not present in the runtime version of the dependency - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER - field = subfield_to_delete.get("field") - field_repeated = subfield_to_delete.get("is_repeated") - subfield = subfield_to_delete.get("subfield") - if subfield: - if field_repeated: - for i in range(0, len(request_init["rag_corpus"][field])): - del request_init["rag_corpus"][field][i][subfield] - else: - del request_init["rag_corpus"][field][subfield] + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
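# Aside: a minimal, self-contained sketch of the round trip every *_call_success
# test in this file relies on: serialize the expected proto with
# json_format.MessageToJson, hand the bytes back through a mocked session (an
# AsyncMock .read() in the rest_asyncio variants), and let the client parse it.
# google.longrunning's Operation is used here only because it is always available.
import asyncio
from unittest import mock

from google.longrunning import operations_pb2
from google.protobuf import json_format

expected = operations_pb2.Operation(name="operations/spam")
payload = json_format.MessageToJson(expected).encode("UTF-8")

fake_response = mock.Mock()
fake_response.status_code = 200
fake_response.read = mock.AsyncMock(return_value=payload)

async def _decode():
    body = await fake_response.read()
    return json_format.Parse(body.decode("UTF-8"), operations_pb2.Operation())

assert asyncio.run(_decode()).name == "operations/spam"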
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data.RagCorpus( + name="name_value", + display_name="display_name_value", + description="description_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data.RagCorpus.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.create_rag_corpus(request) + response = await client.get_rag_corpus(request) # Establish that the response is the type that we expect. - json_return_value = json_format.MessageToJson(return_value) + assert isinstance(response, vertex_rag_data.RagCorpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_create_rag_corpus_rest_asyncio_interceptors(null_interceptor): +async def test_get_rag_corpus_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -10292,20 +12220,18 @@ async def test_create_rag_corpus_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_create_rag_corpus" + transports.AsyncVertexRagDataServiceRestInterceptor, "post_get_rag_corpus" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_create_rag_corpus_with_metadata", + "post_get_rag_corpus_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_create_rag_corpus" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_get_rag_corpus" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.CreateRagCorpusRequest.pb( - vertex_rag_data_service.CreateRagCorpusRequest() + pb_message = vertex_rag_data_service.GetRagCorpusRequest.pb( + vertex_rag_data_service.GetRagCorpusRequest() ) transcode.return_value = { "method": "post", @@ -10317,19 +12243,19 @@ async def test_create_rag_corpus_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = json_format.MessageToJson(operations_pb2.Operation()) + return_value = vertex_rag_data.RagCorpus.to_json(vertex_rag_data.RagCorpus()) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.CreateRagCorpusRequest() + request = vertex_rag_data_service.GetRagCorpusRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() - post_with_metadata.return_value = operations_pb2.Operation(), metadata + post.return_value = vertex_rag_data.RagCorpus() + post_with_metadata.return_value = vertex_rag_data.RagCorpus(), metadata - await 
client.create_rag_corpus( + await client.get_rag_corpus( request, metadata=[ ("key", "val"), @@ -10343,8 +12269,8 @@ async def test_create_rag_corpus_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_update_rag_corpus_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.UpdateRagCorpusRequest, +async def test_list_rag_corpora_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.ListRagCorporaRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -10354,9 +12280,7 @@ async def test_update_rag_corpus_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = { - "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} - } + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -10370,18 +12294,18 @@ async def test_update_rag_corpus_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.update_rag_corpus(request) + await client.list_rag_corpora(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.UpdateRagCorpusRequest, + vertex_rag_data_service.ListRagCorporaRequest, dict, ], ) -async def test_update_rag_corpus_rest_asyncio_call_success(request_type): +async def test_list_rag_corpora_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -10391,136 +12315,38 @@ async def test_update_rag_corpus_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = { - "rag_corpus": {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} - } - request_init["rag_corpus"] = { - "vector_db_config": { - "rag_managed_db": { - "knn": {}, - "ann": {"tree_depth": 1060, "leaf_count": 1056}, - }, - "pinecone": {"index_name": "index_name_value"}, - "vertex_vector_search": { - "index_endpoint": "index_endpoint_value", - "index": "index_value", - }, - "api_auth": { - "api_key_config": { - "api_key_secret_version": "api_key_secret_version_value" - } - }, - "rag_embedding_model_config": { - "vertex_prediction_endpoint": { - "endpoint": "endpoint_value", - "model": "model_value", - "model_version_id": "model_version_id_value", - } - }, - }, - "vertex_ai_search_config": {"serving_config": "serving_config_value"}, - "name": "projects/sample1/locations/sample2/ragCorpora/sample3", - "display_name": "display_name_value", - "description": "description_value", - "create_time": {"seconds": 751, "nanos": 543}, - "update_time": {}, - "corpus_status": {"state": 1, "error_status": "error_status_value"}, - "encryption_spec": {"kms_key_name": "kms_key_name_value"}, - } - # The version of a generated dependency at test runtime may differ from the version used during generation. 
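# Aside: the pruning block below exists because the protos available at test
# runtime may be older or newer than the ones this sample request was generated
# from, so any subfield the runtime message does not define is deleted before the
# request is constructed. The same idea in miniature, using plain dicts and a
# hand-written "known fields" set (prune_unknown_subfields is illustrative only):
def prune_unknown_subfields(sample, known_pairs):
    # known_pairs: the {(field, subfield)} pairs defined by the runtime message.
    for field, value in list(sample.items()):
        items = value if isinstance(value, list) else [value]
        for item in items:
            if isinstance(item, dict):
                for subfield in [s for s in item if (field, s) not in known_pairs]:
                    del item[subfield]
    return sample

sample = {"vector_db_config": {"pinecone": {"index_name": "x"}, "retired_field": 1}}
known = {("vector_db_config", "pinecone")}
assert prune_unknown_subfields(sample, known) == {
    "vector_db_config": {"pinecone": {"index_name": "x"}}
}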
- # Delete any fields which are not present in the current runtime dependency - # See https://github.com/googleapis/gapic-generator-python/issues/1748 - - # Determine if the message type is proto-plus or protobuf - test_field = vertex_rag_data_service.UpdateRagCorpusRequest.meta.fields[ - "rag_corpus" - ] - - def get_message_fields(field): - # Given a field which is a message (composite type), return a list with - # all the fields of the message. - # If the field is not a composite type, return an empty list. - message_fields = [] - - if hasattr(field, "message") and field.message: - is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") - - if is_field_type_proto_plus_type: - message_fields = field.message.meta.fields.values() - # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types - else: # pragma: NO COVER - message_fields = field.message.DESCRIPTOR.fields - return message_fields - - runtime_nested_fields = [ - (field.name, nested_field.name) - for field in get_message_fields(test_field) - for nested_field in get_message_fields(field) - ] - - subfields_not_in_runtime = [] - - # For each item in the sample request, create a list of sub fields which are not present at runtime - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for field, value in request_init["rag_corpus"].items(): # pragma: NO COVER - result = None - is_repeated = False - # For repeated fields - if isinstance(value, list) and len(value): - is_repeated = True - result = value[0] - # For fields where the type is another message - if isinstance(value, dict): - result = value - - if result and hasattr(result, "keys"): - for subfield in result.keys(): - if (field, subfield) not in runtime_nested_fields: - subfields_not_in_runtime.append( - { - "field": field, - "subfield": subfield, - "is_repeated": is_repeated, - } - ) - - # Remove fields from the sample request which are not present in the runtime version of the dependency - # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime - for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER - field = subfield_to_delete.get("field") - field_repeated = subfield_to_delete.get("is_repeated") - subfield = subfield_to_delete.get("subfield") - if subfield: - if field_repeated: - for i in range(0, len(request_init["rag_corpus"][field])): - del request_init["rag_corpus"][field][i][subfield] - else: - del request_init["rag_corpus"][field][subfield] + request_init = {"parent": "projects/sample1/locations/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
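# Aside: the ListRagCorporaAsyncPager asserted a few lines below simply keeps
# issuing requests until next_page_token comes back empty. A dependency-free toy
# version of that loop; fake_list stands in for the transport call and is not a
# real API.
def fake_list(parent, page_token=""):
    pages = {"": (["corpus-1", "corpus-2"], "tok-2"), "tok-2": (["corpus-3"], "")}
    corpora, next_token = pages[page_token]
    return {"rag_corpora": corpora, "next_page_token": next_token}

def list_all_corpora(parent):
    token, collected = "", []
    while True:
        page = fake_list(parent, page_token=token)
        collected.extend(page["rag_corpora"])
        token = page["next_page_token"]
        if not token:
            return collected

assert list_all_corpora("projects/p/locations/l") == ["corpus-1", "corpus-2", "corpus-3"]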
- return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagCorporaResponse( + next_page_token="next_page_token_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagCorporaResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.update_rag_corpus(request) + response = await client.list_rag_corpora(request) # Establish that the response is the type that we expect. - json_return_value = json_format.MessageToJson(return_value) + assert isinstance(response, pagers.ListRagCorporaAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_update_rag_corpus_rest_asyncio_interceptors(null_interceptor): +async def test_list_rag_corpora_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -10538,20 +12364,18 @@ async def test_update_rag_corpus_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_update_rag_corpus" + transports.AsyncVertexRagDataServiceRestInterceptor, "post_list_rag_corpora" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_update_rag_corpus_with_metadata", + "post_list_rag_corpora_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_update_rag_corpus" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_list_rag_corpora" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.UpdateRagCorpusRequest.pb( - vertex_rag_data_service.UpdateRagCorpusRequest() + pb_message = vertex_rag_data_service.ListRagCorporaRequest.pb( + vertex_rag_data_service.ListRagCorporaRequest() ) transcode.return_value = { "method": "post", @@ -10563,19 +12387,24 @@ async def test_update_rag_corpus_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = json_format.MessageToJson(operations_pb2.Operation()) + return_value = vertex_rag_data_service.ListRagCorporaResponse.to_json( + vertex_rag_data_service.ListRagCorporaResponse() + ) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.UpdateRagCorpusRequest() + request = vertex_rag_data_service.ListRagCorporaRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() - post_with_metadata.return_value = operations_pb2.Operation(), metadata + post.return_value = vertex_rag_data_service.ListRagCorporaResponse() + post_with_metadata.return_value = ( + vertex_rag_data_service.ListRagCorporaResponse(), 
+ metadata, + ) - await client.update_rag_corpus( + await client.list_rag_corpora( request, metadata=[ ("key", "val"), @@ -10589,8 +12418,8 @@ async def test_update_rag_corpus_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_get_rag_corpus_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.GetRagCorpusRequest, +async def test_delete_rag_corpus_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.DeleteRagCorpusRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -10614,18 +12443,18 @@ async def test_get_rag_corpus_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.get_rag_corpus(request) + await client.delete_rag_corpus(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.GetRagCorpusRequest, + vertex_rag_data_service.DeleteRagCorpusRequest, dict, ], ) -async def test_get_rag_corpus_rest_asyncio_call_success(request_type): +async def test_delete_rag_corpus_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -10641,36 +12470,26 @@ async def test_get_rag_corpus_rest_asyncio_call_success(request_type): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagCorpus( - name="name_value", - display_name="display_name_value", - description="description_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data.RagCorpus.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.get_rag_corpus(request) + response = await client.delete_rag_corpus(request) # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data.RagCorpus) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" + json_return_value = json_format.MessageToJson(return_value) @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_get_rag_corpus_rest_asyncio_interceptors(null_interceptor): +async def test_delete_rag_corpus_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." 
@@ -10688,18 +12507,20 @@ async def test_get_rag_corpus_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_get_rag_corpus" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, "post_delete_rag_corpus" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_get_rag_corpus_with_metadata", + "post_delete_rag_corpus_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_get_rag_corpus" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_delete_rag_corpus" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.GetRagCorpusRequest.pb( - vertex_rag_data_service.GetRagCorpusRequest() + pb_message = vertex_rag_data_service.DeleteRagCorpusRequest.pb( + vertex_rag_data_service.DeleteRagCorpusRequest() ) transcode.return_value = { "method": "post", @@ -10711,19 +12532,19 @@ async def test_get_rag_corpus_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data.RagCorpus.to_json(vertex_rag_data.RagCorpus()) + return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.GetRagCorpusRequest() + request = vertex_rag_data_service.DeleteRagCorpusRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data.RagCorpus() - post_with_metadata.return_value = vertex_rag_data.RagCorpus(), metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata - await client.get_rag_corpus( + await client.delete_rag_corpus( request, metadata=[ ("key", "val"), @@ -10737,8 +12558,8 @@ async def test_get_rag_corpus_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_list_rag_corpora_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.ListRagCorporaRequest, +async def test_upload_rag_file_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.UploadRagFileRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -10748,7 +12569,7 @@ async def test_list_rag_corpora_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -10762,18 +12583,18 @@ async def test_list_rag_corpora_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.list_rag_corpora(request) + await client.upload_rag_file(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ListRagCorporaRequest, + vertex_rag_data_service.UploadRagFileRequest, dict, ], ) -async def test_list_rag_corpora_rest_asyncio_call_success(request_type): +async def test_upload_rag_file_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -10783,38 +12604,35 @@ async def test_list_rag_corpora_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.ListRagCorporaResponse( - next_page_token="next_page_token_value", - ) + return_value = vertex_rag_data_service.UploadRagFileResponse() # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagCorporaResponse.pb(return_value) + return_value = vertex_rag_data_service.UploadRagFileResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.list_rag_corpora(request) + response = await client.upload_rag_file(request) # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListRagCorporaAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_list_rag_corpora_rest_asyncio_interceptors(null_interceptor): +async def test_upload_rag_file_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." 
@@ -10832,18 +12650,18 @@ async def test_list_rag_corpora_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_list_rag_corpora" + transports.AsyncVertexRagDataServiceRestInterceptor, "post_upload_rag_file" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_list_rag_corpora_with_metadata", + "post_upload_rag_file_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_list_rag_corpora" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_upload_rag_file" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.ListRagCorporaRequest.pb( - vertex_rag_data_service.ListRagCorporaRequest() + pb_message = vertex_rag_data_service.UploadRagFileRequest.pb( + vertex_rag_data_service.UploadRagFileRequest() ) transcode.return_value = { "method": "post", @@ -10855,24 +12673,24 @@ async def test_list_rag_corpora_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data_service.ListRagCorporaResponse.to_json( - vertex_rag_data_service.ListRagCorporaResponse() + return_value = vertex_rag_data_service.UploadRagFileResponse.to_json( + vertex_rag_data_service.UploadRagFileResponse() ) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.ListRagCorporaRequest() + request = vertex_rag_data_service.UploadRagFileRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data_service.ListRagCorporaResponse() + post.return_value = vertex_rag_data_service.UploadRagFileResponse() post_with_metadata.return_value = ( - vertex_rag_data_service.ListRagCorporaResponse(), + vertex_rag_data_service.UploadRagFileResponse(), metadata, ) - await client.list_rag_corpora( + await client.upload_rag_file( request, metadata=[ ("key", "val"), @@ -10886,8 +12704,8 @@ async def test_list_rag_corpora_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_delete_rag_corpus_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.DeleteRagCorpusRequest, +async def test_import_rag_files_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.ImportRagFilesRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -10897,7 +12715,7 @@ async def test_delete_rag_corpus_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -10911,18 +12729,18 @@ async def test_delete_rag_corpus_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.delete_rag_corpus(request) + await client.import_rag_files(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.DeleteRagCorpusRequest, + vertex_rag_data_service.ImportRagFilesRequest, dict, ], ) -async def test_delete_rag_corpus_rest_asyncio_call_success(request_type): +async def test_import_rag_files_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -10932,7 +12750,7 @@ async def test_delete_rag_corpus_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -10949,7 +12767,7 @@ async def test_delete_rag_corpus_rest_asyncio_call_success(request_type): ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.delete_rag_corpus(request) + response = await client.import_rag_files(request) # Establish that the response is the type that we expect. json_return_value = json_format.MessageToJson(return_value) @@ -10957,7 +12775,7 @@ async def test_delete_rag_corpus_rest_asyncio_call_success(request_type): @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_delete_rag_corpus_rest_asyncio_interceptors(null_interceptor): +async def test_import_rag_files_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." 
@@ -10977,18 +12795,18 @@ async def test_delete_rag_corpus_rest_asyncio_interceptors(null_interceptor): ) as transcode, mock.patch.object( operation.Operation, "_set_result_from_operation" ), mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_delete_rag_corpus" + transports.AsyncVertexRagDataServiceRestInterceptor, "post_import_rag_files" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_delete_rag_corpus_with_metadata", + "post_import_rag_files_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_delete_rag_corpus" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_import_rag_files" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.DeleteRagCorpusRequest.pb( - vertex_rag_data_service.DeleteRagCorpusRequest() + pb_message = vertex_rag_data_service.ImportRagFilesRequest.pb( + vertex_rag_data_service.ImportRagFilesRequest() ) transcode.return_value = { "method": "post", @@ -11003,7 +12821,7 @@ async def test_delete_rag_corpus_rest_asyncio_interceptors(null_interceptor): return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.DeleteRagCorpusRequest() + request = vertex_rag_data_service.ImportRagFilesRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), @@ -11012,7 +12830,7 @@ async def test_delete_rag_corpus_rest_asyncio_interceptors(null_interceptor): post.return_value = operations_pb2.Operation() post_with_metadata.return_value = operations_pb2.Operation(), metadata - await client.delete_rag_corpus( + await client.import_rag_files( request, metadata=[ ("key", "val"), @@ -11026,8 +12844,8 @@ async def test_delete_rag_corpus_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_upload_rag_file_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.UploadRagFileRequest, +async def test_get_rag_file_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.GetRagFileRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -11037,7 +12855,9 @@ async def test_upload_rag_file_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -11051,18 +12871,18 @@ async def test_upload_rag_file_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.upload_rag_file(request) + await client.get_rag_file(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.UploadRagFileRequest, + vertex_rag_data_service.GetRagFileRequest, dict, ], ) -async def test_upload_rag_file_rest_asyncio_call_success(request_type): +async def test_get_rag_file_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." 
@@ -11072,35 +12892,44 @@ async def test_upload_rag_file_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.UploadRagFileResponse() + return_value = vertex_rag_data.RagFile( + name="name_value", + display_name="display_name_value", + description="description_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 # Convert return value to protobuf type - return_value = vertex_rag_data_service.UploadRagFileResponse.pb(return_value) + return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.upload_rag_file(request) + response = await client.get_rag_file(request) # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data_service.UploadRagFileResponse) + assert isinstance(response, vertex_rag_data.RagFile) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_upload_rag_file_rest_asyncio_interceptors(null_interceptor): +async def test_get_rag_file_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." 
@@ -11118,18 +12947,18 @@ async def test_upload_rag_file_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_upload_rag_file" + transports.AsyncVertexRagDataServiceRestInterceptor, "post_get_rag_file" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_upload_rag_file_with_metadata", + "post_get_rag_file_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_upload_rag_file" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_get_rag_file" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.UploadRagFileRequest.pb( - vertex_rag_data_service.UploadRagFileRequest() + pb_message = vertex_rag_data_service.GetRagFileRequest.pb( + vertex_rag_data_service.GetRagFileRequest() ) transcode.return_value = { "method": "post", @@ -11141,24 +12970,19 @@ async def test_upload_rag_file_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data_service.UploadRagFileResponse.to_json( - vertex_rag_data_service.UploadRagFileResponse() - ) + return_value = vertex_rag_data.RagFile.to_json(vertex_rag_data.RagFile()) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.UploadRagFileRequest() + request = vertex_rag_data_service.GetRagFileRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data_service.UploadRagFileResponse() - post_with_metadata.return_value = ( - vertex_rag_data_service.UploadRagFileResponse(), - metadata, - ) + post.return_value = vertex_rag_data.RagFile() + post_with_metadata.return_value = vertex_rag_data.RagFile(), metadata - await client.upload_rag_file( + await client.get_rag_file( request, metadata=[ ("key", "val"), @@ -11172,8 +12996,8 @@ async def test_upload_rag_file_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_import_rag_files_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.ImportRagFilesRequest, +async def test_list_rag_files_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.ListRagFilesRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -11197,18 +13021,18 @@ async def test_import_rag_files_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.import_rag_files(request) + await client.list_rag_files(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ImportRagFilesRequest, + vertex_rag_data_service.ListRagFilesRequest, dict, ], ) -async def test_import_rag_files_rest_asyncio_call_success(request_type): +async def test_list_rag_files_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -11224,26 +13048,32 @@ async def test_import_rag_files_rest_asyncio_call_success(request_type): # Mock the http request call within the method and fake a response. 
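# Aside: the sample names in these requests follow the resource-name template
# "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}".
# Generated clients typically expose *_path()/parse_*_path() helpers for this; the
# equivalent done by hand, for illustration only (rag_file_name/parse_rag_file_name
# are not library functions):
import re

_TEMPLATE = (
    "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}"
)
_PATTERN = re.compile(
    r"^projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)"
    r"/ragCorpora/(?P<rag_corpus>[^/]+)/ragFiles/(?P<rag_file>[^/]+)$"
)

def rag_file_name(project, location, rag_corpus, rag_file):
    return _TEMPLATE.format(
        project=project, location=location, rag_corpus=rag_corpus, rag_file=rag_file
    )

def parse_rag_file_name(name):
    match = _PATTERN.match(name)
    return match.groupdict() if match else {}

name = rag_file_name("sample1", "sample2", "sample3", "sample4")
assert parse_rag_file_name(name)["rag_file"] == "sample4"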
with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data_service.ListRagFilesResponse( + next_page_token="next_page_token_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.import_rag_files(request) + response = await client.list_rag_files(request) # Establish that the response is the type that we expect. - json_return_value = json_format.MessageToJson(return_value) + assert isinstance(response, pagers.ListRagFilesAsyncPager) + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_import_rag_files_rest_asyncio_interceptors(null_interceptor): +async def test_list_rag_files_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -11261,20 +13091,18 @@ async def test_import_rag_files_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_import_rag_files" + transports.AsyncVertexRagDataServiceRestInterceptor, "post_list_rag_files" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_import_rag_files_with_metadata", + "post_list_rag_files_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_import_rag_files" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_list_rag_files" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.ImportRagFilesRequest.pb( - vertex_rag_data_service.ImportRagFilesRequest() + pb_message = vertex_rag_data_service.ListRagFilesRequest.pb( + vertex_rag_data_service.ListRagFilesRequest() ) transcode.return_value = { "method": "post", @@ -11286,19 +13114,24 @@ async def test_import_rag_files_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = json_format.MessageToJson(operations_pb2.Operation()) + return_value = vertex_rag_data_service.ListRagFilesResponse.to_json( + vertex_rag_data_service.ListRagFilesResponse() + ) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.ImportRagFilesRequest() + request = vertex_rag_data_service.ListRagFilesRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() - post_with_metadata.return_value = operations_pb2.Operation(), metadata + post.return_value = 
vertex_rag_data_service.ListRagFilesResponse() + post_with_metadata.return_value = ( + vertex_rag_data_service.ListRagFilesResponse(), + metadata, + ) - await client.import_rag_files( + await client.list_rag_files( request, metadata=[ ("key", "val"), @@ -11312,8 +13145,8 @@ async def test_import_rag_files_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_get_rag_file_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.GetRagFileRequest, +async def test_delete_rag_file_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.DeleteRagFileRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -11339,18 +13172,18 @@ async def test_get_rag_file_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.get_rag_file(request) + await client.delete_rag_file(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.GetRagFileRequest, + vertex_rag_data_service.DeleteRagFileRequest, dict, ], ) -async def test_get_rag_file_rest_asyncio_call_success(request_type): +async def test_delete_rag_file_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -11368,36 +13201,26 @@ async def test_get_rag_file_rest_asyncio_call_success(request_type): # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data.RagFile( - name="name_value", - display_name="display_name_value", - description="description_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data.RagFile.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.get_rag_file(request) + response = await client.delete_rag_file(request) # Establish that the response is the type that we expect. - assert isinstance(response, vertex_rag_data.RagFile) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" + json_return_value = json_format.MessageToJson(return_value) @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_get_rag_file_rest_asyncio_interceptors(null_interceptor): +async def test_delete_rag_file_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." 
@@ -11415,18 +13238,20 @@ async def test_get_rag_file_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_get_rag_file" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, "post_delete_rag_file" ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_get_rag_file_with_metadata", + "post_delete_rag_file_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_get_rag_file" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_delete_rag_file" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.GetRagFileRequest.pb( - vertex_rag_data_service.GetRagFileRequest() + pb_message = vertex_rag_data_service.DeleteRagFileRequest.pb( + vertex_rag_data_service.DeleteRagFileRequest() ) transcode.return_value = { "method": "post", @@ -11438,19 +13263,19 @@ async def test_get_rag_file_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data.RagFile.to_json(vertex_rag_data.RagFile()) + return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.GetRagFileRequest() + request = vertex_rag_data_service.DeleteRagFileRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data.RagFile() - post_with_metadata.return_value = vertex_rag_data.RagFile(), metadata + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata - await client.get_rag_file( + await client.delete_rag_file( request, metadata=[ ("key", "val"), @@ -11464,8 +13289,8 @@ async def test_get_rag_file_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_list_rag_files_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.ListRagFilesRequest, +async def test_update_rag_engine_config_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.UpdateRagEngineConfigRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -11475,7 +13300,11 @@ async def test_list_rag_files_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "rag_engine_config": { + "name": "projects/sample1/locations/sample2/ragEngineConfig" + } + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -11489,18 +13318,18 @@ async def test_list_rag_files_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.list_rag_files(request) + await client.update_rag_engine_config(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.ListRagFilesRequest, + vertex_rag_data_service.UpdateRagEngineConfigRequest, dict, ], ) -async def test_list_rag_files_rest_asyncio_call_success(request_type): +async def test_update_rag_engine_config_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -11510,38 +13339,109 @@ async def test_list_rag_files_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2/ragCorpora/sample3"} + request_init = { + "rag_engine_config": { + "name": "projects/sample1/locations/sample2/ragEngineConfig" + } + } + request_init["rag_engine_config"] = { + "name": "projects/sample1/locations/sample2/ragEngineConfig", + "rag_managed_db_config": {"scaled": {}, "basic": {}, "unprovisioned": {}}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = vertex_rag_data_service.UpdateRagEngineConfigRequest.meta.fields[ + "rag_engine_config" + ] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. 
+ message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["rag_engine_config"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["rag_engine_config"][field])): + del request_init["rag_engine_config"][field][i][subfield] + else: + del request_init["rag_engine_config"][field][subfield] request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = vertex_rag_data_service.ListRagFilesResponse( - next_page_token="next_page_token_value", - ) + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 - - # Convert return value to protobuf type - return_value = vertex_rag_data_service.ListRagFilesResponse.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.list_rag_files(request) + response = await client.update_rag_engine_config(request) # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListRagFilesAsyncPager) - assert response.next_page_token == "next_page_token_value" + json_return_value = json_format.MessageToJson(return_value) @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_list_rag_files_rest_asyncio_interceptors(null_interceptor): +async def test_update_rag_engine_config_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -11559,18 +13459,22 @@ async def test_list_rag_files_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_list_rag_files" + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.AsyncVertexRagDataServiceRestInterceptor, + "post_update_rag_engine_config", ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_list_rag_files_with_metadata", + "post_update_rag_engine_config_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_list_rag_files" + transports.AsyncVertexRagDataServiceRestInterceptor, + "pre_update_rag_engine_config", ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.ListRagFilesRequest.pb( - vertex_rag_data_service.ListRagFilesRequest() + pb_message = vertex_rag_data_service.UpdateRagEngineConfigRequest.pb( + vertex_rag_data_service.UpdateRagEngineConfigRequest() ) transcode.return_value = { "method": "post", @@ -11582,24 +13486,19 @@ async def test_list_rag_files_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = vertex_rag_data_service.ListRagFilesResponse.to_json( - vertex_rag_data_service.ListRagFilesResponse() - ) + return_value = json_format.MessageToJson(operations_pb2.Operation()) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.ListRagFilesRequest() + request = vertex_rag_data_service.UpdateRagEngineConfigRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = vertex_rag_data_service.ListRagFilesResponse() - post_with_metadata.return_value = ( - vertex_rag_data_service.ListRagFilesResponse(), - metadata, - ) + post.return_value = operations_pb2.Operation() + post_with_metadata.return_value = operations_pb2.Operation(), metadata - await client.list_rag_files( + await client.update_rag_engine_config( request, metadata=[ ("key", "val"), @@ -11613,8 +13512,8 @@ async def test_list_rag_files_rest_asyncio_interceptors(null_interceptor): @pytest.mark.asyncio -async def test_delete_rag_file_rest_asyncio_bad_request( - request_type=vertex_rag_data_service.DeleteRagFileRequest, +async def test_get_rag_engine_config_rest_asyncio_bad_request( + request_type=vertex_rag_data_service.GetRagEngineConfigRequest, ): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -11624,9 +13523,7 @@ async def test_delete_rag_file_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = { - "name": 
"projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" - } + request_init = {"name": "projects/sample1/locations/sample2/ragEngineConfig"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -11640,18 +13537,18 @@ async def test_delete_rag_file_rest_asyncio_bad_request( response_value.request = mock.Mock() req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - await client.delete_rag_file(request) + await client.get_rag_engine_config(request) @pytest.mark.asyncio @pytest.mark.parametrize( "request_type", [ - vertex_rag_data_service.DeleteRagFileRequest, + vertex_rag_data_service.GetRagEngineConfigRequest, dict, ], ) -async def test_delete_rag_file_rest_asyncio_call_success(request_type): +async def test_get_rag_engine_config_rest_asyncio_call_success(request_type): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." @@ -11661,34 +13558,38 @@ async def test_delete_rag_file_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = { - "name": "projects/sample1/locations/sample2/ragCorpora/sample3/ragFiles/sample4" - } + request_init = {"name": "projects/sample1/locations/sample2/ragEngineConfig"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name="operations/spam") + return_value = vertex_rag_data.RagEngineConfig( + name="name_value", + ) # Wrap the value into a proper Response obj response_value = mock.Mock() response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = vertex_rag_data.RagEngineConfig.pb(return_value) json_return_value = json_format.MessageToJson(return_value) response_value.read = mock.AsyncMock( return_value=json_return_value.encode("UTF-8") ) req.return_value = response_value req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - response = await client.delete_rag_file(request) + response = await client.get_rag_engine_config(request) # Establish that the response is the type that we expect. - json_return_value = json_format.MessageToJson(return_value) + assert isinstance(response, vertex_rag_data.RagEngineConfig) + assert response.name == "name_value" @pytest.mark.asyncio @pytest.mark.parametrize("null_interceptor", [True, False]) -async def test_delete_rag_file_rest_asyncio_interceptors(null_interceptor): +async def test_get_rag_engine_config_rest_asyncio_interceptors(null_interceptor): if not HAS_ASYNC_REST_EXTRA: pytest.skip( "the library must be installed with the `async_rest` extra to test this feature." 
@@ -11706,20 +13607,19 @@ async def test_delete_rag_file_rest_asyncio_interceptors(null_interceptor): ) as req, mock.patch.object( path_template, "transcode" ) as transcode, mock.patch.object( - operation.Operation, "_set_result_from_operation" - ), mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "post_delete_rag_file" + transports.AsyncVertexRagDataServiceRestInterceptor, + "post_get_rag_engine_config", ) as post, mock.patch.object( transports.AsyncVertexRagDataServiceRestInterceptor, - "post_delete_rag_file_with_metadata", + "post_get_rag_engine_config_with_metadata", ) as post_with_metadata, mock.patch.object( - transports.AsyncVertexRagDataServiceRestInterceptor, "pre_delete_rag_file" + transports.AsyncVertexRagDataServiceRestInterceptor, "pre_get_rag_engine_config" ) as pre: pre.assert_not_called() post.assert_not_called() post_with_metadata.assert_not_called() - pb_message = vertex_rag_data_service.DeleteRagFileRequest.pb( - vertex_rag_data_service.DeleteRagFileRequest() + pb_message = vertex_rag_data_service.GetRagEngineConfigRequest.pb( + vertex_rag_data_service.GetRagEngineConfigRequest() ) transcode.return_value = { "method": "post", @@ -11731,19 +13631,21 @@ async def test_delete_rag_file_rest_asyncio_interceptors(null_interceptor): req.return_value = mock.Mock() req.return_value.status_code = 200 req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} - return_value = json_format.MessageToJson(operations_pb2.Operation()) + return_value = vertex_rag_data.RagEngineConfig.to_json( + vertex_rag_data.RagEngineConfig() + ) req.return_value.read = mock.AsyncMock(return_value=return_value) - request = vertex_rag_data_service.DeleteRagFileRequest() + request = vertex_rag_data_service.GetRagEngineConfigRequest() metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - post.return_value = operations_pb2.Operation() - post_with_metadata.return_value = operations_pb2.Operation(), metadata + post.return_value = vertex_rag_data.RagEngineConfig() + post_with_metadata.return_value = vertex_rag_data.RagEngineConfig(), metadata - await client.delete_rag_file( + await client.get_rag_engine_config( request, metadata=[ ("key", "val"), @@ -12760,6 +14662,60 @@ async def test_delete_rag_file_empty_call_rest_asyncio(): assert args[0] == request_msg +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_update_rag_engine_config_empty_call_rest_asyncio(): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." + ) + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="rest_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_rag_engine_config), "__call__" + ) as call: + await client.update_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.UpdateRagEngineConfigRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. 
+@pytest.mark.asyncio +async def test_get_rag_engine_config_empty_call_rest_asyncio(): + if not HAS_ASYNC_REST_EXTRA: + pytest.skip( + "the library must be installed with the `async_rest` extra to test this feature." + ) + client = VertexRagDataServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="rest_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.get_rag_engine_config), "__call__" + ) as call: + await client.get_rag_engine_config(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = vertex_rag_data_service.GetRagEngineConfigRequest() + + assert args[0] == request_msg + + def test_vertex_rag_data_service_rest_asyncio_lro_client(): if not HAS_ASYNC_REST_EXTRA: pytest.skip( @@ -12838,6 +14794,8 @@ def test_vertex_rag_data_service_base_transport(): "get_rag_file", "list_rag_files", "delete_rag_file", + "update_rag_engine_config", + "get_rag_engine_config", "set_iam_policy", "get_iam_policy", "test_iam_permissions", @@ -13145,6 +15103,12 @@ def test_vertex_rag_data_service_client_transport_session_collision(transport_na session1 = client1.transport.delete_rag_file._session session2 = client2.transport.delete_rag_file._session assert session1 != session2 + session1 = client1.transport.update_rag_engine_config._session + session2 = client2.transport.update_rag_engine_config._session + assert session1 != session2 + session1 = client1.transport.get_rag_engine_config._session + session2 = client2.transport.get_rag_engine_config._session + assert session1 != session2 def test_vertex_rag_data_service_grpc_transport_channel(): @@ -13385,11 +15349,34 @@ def test_parse_rag_corpus_path(): assert expected == actual -def test_rag_file_path(): +def test_rag_engine_config_path(): project = "cuttlefish" location = "mussel" - rag_corpus = "winkle" - rag_file = "nautilus" + expected = "projects/{project}/locations/{location}/ragEngineConfig".format( + project=project, + location=location, + ) + actual = VertexRagDataServiceClient.rag_engine_config_path(project, location) + assert expected == actual + + +def test_parse_rag_engine_config_path(): + expected = { + "project": "winkle", + "location": "nautilus", + } + path = VertexRagDataServiceClient.rag_engine_config_path(**expected) + + # Check that the path construction is reversible. 
+ actual = VertexRagDataServiceClient.parse_rag_engine_config_path(path) + assert expected == actual + + +def test_rag_file_path(): + project = "scallop" + location = "abalone" + rag_corpus = "squid" + rag_file = "clam" expected = "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}".format( project=project, location=location, @@ -13404,10 +15391,10 @@ def test_rag_file_path(): def test_parse_rag_file_path(): expected = { - "project": "scallop", - "location": "abalone", - "rag_corpus": "squid", - "rag_file": "clam", + "project": "whelk", + "location": "octopus", + "rag_corpus": "oyster", + "rag_file": "nudibranch", } path = VertexRagDataServiceClient.rag_file_path(**expected) @@ -13417,9 +15404,9 @@ def test_parse_rag_file_path(): def test_secret_version_path(): - project = "whelk" - secret = "octopus" - secret_version = "oyster" + project = "cuttlefish" + secret = "mussel" + secret_version = "winkle" expected = "projects/{project}/secrets/{secret}/versions/{secret_version}".format( project=project, secret=secret, @@ -13433,9 +15420,9 @@ def test_secret_version_path(): def test_parse_secret_version_path(): expected = { - "project": "nudibranch", - "secret": "cuttlefish", - "secret_version": "mussel", + "project": "nautilus", + "secret": "scallop", + "secret_version": "abalone", } path = VertexRagDataServiceClient.secret_version_path(**expected) @@ -13445,7 +15432,7 @@ def test_parse_secret_version_path(): def test_common_billing_account_path(): - billing_account = "winkle" + billing_account = "squid" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -13455,7 +15442,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "clam", } path = VertexRagDataServiceClient.common_billing_account_path(**expected) @@ -13465,7 +15452,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "scallop" + folder = "whelk" expected = "folders/{folder}".format( folder=folder, ) @@ -13475,7 +15462,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "octopus", } path = VertexRagDataServiceClient.common_folder_path(**expected) @@ -13485,7 +15472,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "squid" + organization = "oyster" expected = "organizations/{organization}".format( organization=organization, ) @@ -13495,7 +15482,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "nudibranch", } path = VertexRagDataServiceClient.common_organization_path(**expected) @@ -13505,7 +15492,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "whelk" + project = "cuttlefish" expected = "projects/{project}".format( project=project, ) @@ -13515,7 +15502,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "mussel", } path = VertexRagDataServiceClient.common_project_path(**expected) @@ -13525,8 +15512,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "oyster" - location = "nudibranch" + project = "winkle" + location = "nautilus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -13537,8 +15524,8 @@ def 
test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "scallop", + "location": "abalone", } path = VertexRagDataServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index feb94df682..41e21494a6 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -24446,7 +24446,16 @@ def test_create_custom_job_rest_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, "protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -26110,7 +26119,16 @@ def test_create_hyperparameter_tuning_job_rest_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, "protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -26989,7 +27007,14 @@ def test_create_nas_job_rest_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "base_output_directory": { "output_uri_prefix": "output_uri_prefix_value" @@ -31943,7 +31968,16 @@ async def test_create_custom_job_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, "protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -33772,7 +33806,16 @@ async def test_create_hyperparameter_tuning_job_rest_asyncio_call_success(reques "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "base_output_directory": {"output_uri_prefix": "output_uri_prefix_value"}, 
"protected_artifact_location_id": "protected_artifact_location_id_value", "tensorboard": "tensorboard_value", @@ -34749,7 +34792,14 @@ async def test_create_nas_job_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "base_output_directory": { "output_uri_prefix": "output_uri_prefix_value" diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py index 252654a36e..88e61962e2 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_persistent_resource_service.py @@ -5349,7 +5349,16 @@ def test_create_persistent_resource_rest_call_success(request_type): "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { @@ -6055,7 +6064,16 @@ def test_update_persistent_resource_rest_call_success(request_type): "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { @@ -7262,7 +7280,16 @@ async def test_create_persistent_resource_rest_asyncio_call_success(request_type "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { @@ -8032,7 +8059,16 @@ async def test_update_persistent_resource_rest_asyncio_call_success(request_type "update_time": {}, "labels": {}, "network": "network_value", - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "encryption_spec": {"kms_key_name": "kms_key_name_value"}, "resource_runtime_spec": { "service_account_spec": { diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index 5991794864..37ed2552e3 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ 
b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -10113,7 +10113,16 @@ def test_create_pipeline_job_rest_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, "schedule_name": "schedule_name_value", @@ -13155,7 +13164,16 @@ async def test_create_pipeline_job_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value1", "reserved_ip_ranges_value2", ], - "psc_interface_config": {"network_attachment": "network_attachment_value"}, + "psc_interface_config": { + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], + }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, "schedule_name": "schedule_name_value", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py index fe5f36d61e..d2c850a57d 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_schedule_service.py @@ -5723,7 +5723,14 @@ def test_create_schedule_rest_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, @@ -6903,7 +6910,14 @@ def test_update_schedule_rest_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, @@ -8273,7 +8287,14 @@ async def test_create_schedule_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": "version_value"}, @@ -9553,7 +9574,14 @@ async def test_update_schedule_rest_asyncio_call_success(request_type): "reserved_ip_ranges_value2", ], "psc_interface_config": { - "network_attachment": "network_attachment_value" + "network_attachment": "network_attachment_value", + "dns_peering_configs": [ + { + "domain": "domain_value", + "target_project": "target_project_value", + "target_network": "target_network_value", + } + ], }, "template_uri": "template_uri_value", "template_metadata": {"version": 
"version_value"}, diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py index 91798a544d..78758d0be1 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vertex_rag_data_service.py @@ -10512,7 +10512,12 @@ def test_update_rag_engine_config_rest_call_success(request_type): } request_init["rag_engine_config"] = { "name": "projects/sample1/locations/sample2/ragEngineConfig", - "rag_managed_db_config": {"enterprise": {}, "basic": {}}, + "rag_managed_db_config": { + "enterprise": {}, + "scaled": {}, + "basic": {}, + "unprovisioned": {}, + }, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -13486,7 +13491,12 @@ async def test_update_rag_engine_config_rest_asyncio_call_success(request_type): } request_init["rag_engine_config"] = { "name": "projects/sample1/locations/sample2/ragEngineConfig", - "rag_managed_db_config": {"enterprise": {}, "basic": {}}, + "rag_managed_db_config": { + "enterprise": {}, + "scaled": {}, + "basic": {}, + "unprovisioned": {}, + }, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency From 4df909cdfddf071c4b87a7d1dabed56437d528a2 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 17 Jun 2025 16:58:15 -0700 Subject: [PATCH 20/24] feat: Added `autoscaling_target_request_count_per_minute` to model deployment on Endpoint and Model classes PiperOrigin-RevId: 772677877 --- google/cloud/aiplatform/models.py | 20 +++++++++ tests/unit/aiplatform/test_endpoints.py | 52 ++++++++++++++++++++++- tests/unit/aiplatform/test_models.py | 56 +++++++++++++++++++++++++ 3 files changed, 127 insertions(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index 7768afe57c..3ff1a6dd6a 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -1363,6 +1363,7 @@ def deploy( deploy_request_timeout: Optional[float] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + autoscaling_target_request_count_per_minute: Optional[int] = None, enable_access_logging=False, disable_container_logging: bool = False, deployment_resource_pool: Optional[DeploymentResourcePool] = None, @@ -1456,6 +1457,9 @@ def deploy( Target Accelerator Duty Cycle. Must also set accelerator_type and accelerator_count if specified. A default value of 60 will be used if not specified. + autoscaling_target_request_count_per_minute (int): + Optional. The target number of requests per minute for autoscaling. + If set, the model will be scaled based on the number of requests it receives. enable_access_logging (bool): Whether to enable endpoint access logging. Defaults to False. 
disable_container_logging (bool): @@ -1536,6 +1540,7 @@ def deploy( deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + autoscaling_target_request_count_per_minute=autoscaling_target_request_count_per_minute, spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, @@ -1568,6 +1573,7 @@ def _deploy( deploy_request_timeout: Optional[float] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + autoscaling_target_request_count_per_minute: Optional[int] = None, spot: bool = False, enable_access_logging=False, disable_container_logging: bool = False, @@ -1664,6 +1670,9 @@ def _deploy( Target Accelerator Duty Cycle. Must also set accelerator_type and accelerator_count if specified. A default value of 60 will be used if not specified. + autoscaling_target_request_count_per_minute (int): + Optional. The target number of requests per minute for autoscaling. + If set, the model will be scaled based on the number of requests it receives. spot (bool): Optional. Whether to schedule the deployment workload on spot VMs. enable_access_logging (bool): @@ -1721,6 +1730,7 @@ def _deploy( deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + autoscaling_target_request_count_per_minute=autoscaling_target_request_count_per_minute, spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, @@ -5339,6 +5349,7 @@ def deploy( deploy_request_timeout: Optional[float] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + autoscaling_target_request_count_per_minute: Optional[int] = None, enable_access_logging=False, disable_container_logging: bool = False, private_service_connect_config: Optional[ @@ -5454,6 +5465,9 @@ def deploy( Optional. Target Accelerator Duty Cycle. Must also set accelerator_type and accelerator_count if specified. A default value of 60 will be used if not specified. + autoscaling_target_request_count_per_minute (int): + Optional. The target number of requests per minute for autoscaling. + If set, the model will be scaled based on the number of requests it receives. enable_access_logging (bool): Whether to enable endpoint access logging. Defaults to False. disable_container_logging (bool): @@ -5561,6 +5575,7 @@ def deploy( deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + autoscaling_target_request_count_per_minute=autoscaling_target_request_count_per_minute, spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, @@ -5603,6 +5618,7 @@ def _deploy( deploy_request_timeout: Optional[float] = None, autoscaling_target_cpu_utilization: Optional[int] = None, autoscaling_target_accelerator_duty_cycle: Optional[int] = None, + autoscaling_target_request_count_per_minute: Optional[int] = None, spot: bool = False, enable_access_logging=False, disable_container_logging: bool = False, @@ -5720,6 +5736,9 @@ def _deploy( Optional. 
Target Accelerator Duty Cycle. Must also set accelerator_type and accelerator_count if specified. A default value of 60 will be used if not specified. + autoscaling_target_request_count_per_minute (int): + Optional. The target number of requests per minute for autoscaling. + If set, the model will be scaled based on the number of requests it receives. spot (bool): Optional. Whether to schedule the deployment workload on spot VMs. enable_access_logging (bool): @@ -5808,6 +5827,7 @@ def _deploy( deploy_request_timeout=deploy_request_timeout, autoscaling_target_cpu_utilization=autoscaling_target_cpu_utilization, autoscaling_target_accelerator_duty_cycle=autoscaling_target_accelerator_duty_cycle, + autoscaling_target_request_count_per_minute=autoscaling_target_request_count_per_minute, spot=spot, enable_access_logging=enable_access_logging, disable_container_logging=disable_container_logging, diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index e6297183c9..84de3c6282 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -1917,11 +1917,61 @@ def test_deploy_with_autoscaling_target_accelerator_duty_cycle_and_no_accelerato if not sync: test_endpoint.wait() + @pytest.mark.usefixtures("get_endpoint_mock", "get_model_mock") + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_with_autoscaling_target_request_count_per_minute( + self, deploy_model_mock, sync + ): + test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + test_model = models.Model(_TEST_ID) + test_model._gca_resource.supported_deployment_resources_types.append( + aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ) + test_endpoint.deploy( + model=test_model, + machine_type=_TEST_MACHINE_TYPE, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + deploy_request_timeout=None, + autoscaling_target_request_count_per_minute=600, + ) + + if not sync: + test_endpoint.wait() + + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + ), + min_replica_count=1, + max_replica_count=1, + autoscaling_metric_specs=[ + gca_machine_resources.AutoscalingMetricSpec( + metric_name=_TEST_METRIC_NAME_REQUEST_COUNT, + target=600, + ), + ], + ) + + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + service_account=_TEST_SERVICE_ACCOUNT, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + timeout=None, + ) + @pytest.mark.usefixtures( "get_endpoint_mock", "get_model_mock", "preview_deploy_model_mock" ) @pytest.mark.parametrize("sync", [True, False]) - def test_deploy_with_autoscaling_target_request_count_per_minute( + def test_deploy_with_autoscaling_target_request_count_per_minute_preview( self, preview_deploy_model_mock, sync ): test_endpoint = preview_models.Endpoint(_TEST_ENDPOINT_NAME) diff --git a/tests/unit/aiplatform/test_models.py b/tests/unit/aiplatform/test_models.py index ef06ecae76..ba95f42717 100644 --- a/tests/unit/aiplatform/test_models.py +++ b/tests/unit/aiplatform/test_models.py @@ -2386,6 +2386,62 @@ def test_deploy_no_endpoint_dedicated_resources_autoscaling_accelerator_duty_cyc if not sync: test_endpoint.wait() + @pytest.mark.usefixtures( + "get_model_mock", + 
"create_endpoint_mock", + "get_endpoint_mock", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_deploy_no_endpoint_dedicated_resources_autoscaling_request_count_per_minute( + self, deploy_model_mock, sync + ): + test_model = models.Model(_TEST_ID) + test_model._gca_resource.supported_deployment_resources_types.append( + aiplatform.gapic.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ) + + test_endpoint = test_model.deploy( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + sync=sync, + deploy_request_timeout=None, + system_labels=_TEST_LABELS, + autoscaling_target_request_count_per_minute=600, + ) + + if not sync: + test_endpoint.wait() + + expected_dedicated_resources = gca_machine_resources.DedicatedResources( + machine_spec=gca_machine_resources.MachineSpec( + machine_type=_TEST_MACHINE_TYPE, + accelerator_type=_TEST_ACCELERATOR_TYPE, + accelerator_count=_TEST_ACCELERATOR_COUNT, + ), + min_replica_count=1, + max_replica_count=1, + autoscaling_metric_specs=[ + gca_machine_resources.AutoscalingMetricSpec( + metric_name=_TEST_METRIC_NAME_REQUEST_COUNT, + target=600, + ), + ], + ) + expected_deployed_model = gca_endpoint.DeployedModel( + dedicated_resources=expected_dedicated_resources, + model=test_model.resource_name, + display_name=None, + system_labels=_TEST_LABELS, + ) + deploy_model_mock.assert_called_once_with( + endpoint=test_endpoint.resource_name, + deployed_model=expected_deployed_model, + traffic_split={"0": 100}, + metadata=(), + timeout=None, + ) + @pytest.mark.usefixtures( "get_model_mock", "create_endpoint_mock", From 726d3a26c81b95163416af5b3da2367086317573 Mon Sep 17 00:00:00 2001 From: Darshan Mehta Date: Wed, 18 Jun 2025 11:45:17 -0700 Subject: [PATCH 21/24] feat: RAG - Add Scaled and Unprovisioned tier in preview. feat: RAG - Implement v1 `update_rag_engine_config` in `rag_data.py` feat: RAG - Implement v1 `get_rag_engine_config` in `rag_data.py` feat: RAG - Add Basic, Scaled and Unprovisioned tier in v1. 
PiperOrigin-RevId: 773003445 --- tests/unit/vertex_rag/test_rag_constants.py | 46 ++++ .../vertex_rag/test_rag_constants_preview.py | 22 ++ tests/unit/vertex_rag/test_rag_data.py | 206 ++++++++++++++++++ .../unit/vertex_rag/test_rag_data_preview.py | 120 +++++++++- vertexai/preview/rag/__init__.py | 4 + vertexai/preview/rag/utils/_gapic_utils.py | 14 +- vertexai/preview/rag/utils/resources.py | 30 ++- vertexai/rag/__init__.py | 34 ++- vertexai/rag/rag_data.py | 73 +++++++ vertexai/rag/utils/_gapic_utils.py | 59 ++++- vertexai/rag/utils/resources.py | 61 ++++++ 11 files changed, 647 insertions(+), 22 deletions(-) diff --git a/tests/unit/vertex_rag/test_rag_constants.py b/tests/unit/vertex_rag/test_rag_constants.py index f99e042b7b..65459cec2a 100644 --- a/tests/unit/vertex_rag/test_rag_constants.py +++ b/tests/unit/vertex_rag/test_rag_constants.py @@ -19,24 +19,29 @@ from google.cloud import aiplatform from vertexai.rag import ( + Basic, Filter, LayoutParserConfig, LlmParserConfig, LlmRanker, Pinecone, RagCorpus, + RagEngineConfig, RagFile, + RagManagedDbConfig, RagResource, RagRetrievalConfig, RagVectorDbConfig, Ranking, RankService, + Scaled, SharePointSource, SharePointSources, SlackChannelsSource, SlackChannel, JiraSource, JiraQuery, + Unprovisioned, VertexVectorSearch, RagEmbeddingModelConfig, VertexAiSearchConfig, @@ -45,9 +50,11 @@ from google.cloud.aiplatform_v1 import ( GoogleDriveSource, + RagEngineConfig as GapicRagEngineConfig, RagFileChunkingConfig, RagFileParsingConfig, RagFileTransformationConfig, + RagManagedDbConfig as GapicRagManagedDbConfig, ImportRagFilesConfig, ImportRagFilesRequest, ImportRagFilesResponse, @@ -677,6 +684,45 @@ import_rag_files_config=TEST_IMPORT_FILES_CONFIG_LLM_PARSER, ) +# RagEngineConfig Resource +TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = ( + f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragEngineConfig" +) +TEST_RAG_ENGINE_CONFIG_BASIC = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig(tier=Basic()), +) +TEST_RAG_ENGINE_CONFIG_SCALED = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig(tier=Scaled()), +) +TEST_RAG_ENGINE_CONFIG_UNPROVISIONED = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig(tier=Unprovisioned()), +) +TEST_DEFAULT_RAG_ENGINE_CONFIG = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=None, +) +TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + basic=GapicRagManagedDbConfig.Basic() + ), +) +TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + scaled=GapicRagManagedDbConfig.Scaled() + ), +) +TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + unprovisioned=GapicRagManagedDbConfig.Unprovisioned() + ), +) + # Inline Citations test constants TEST_ORIGINAL_TEXT = ( "You can activate the parking radar using a switch or through the" diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index 08a8896f79..9fd79fe457 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -67,10 +67,12 @@ 
RagVectorDbConfig, RankService, Ranking, + Scaled, SharePointSource, SharePointSources, SlackChannel, SlackChannelsSource, + Unprovisioned, VertexAiSearchConfig, VertexFeatureStore, VertexPredictionEndpoint, @@ -561,6 +563,14 @@ name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=RagManagedDbConfig(tier=Basic()), ) +TEST_RAG_ENGINE_CONFIG_SCALED = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig(tier=Scaled()), +) +TEST_RAG_ENGINE_CONFIG_UNPROVISIONED = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig(tier=Unprovisioned()), +) TEST_RAG_ENGINE_CONFIG_ENTERPRISE = RagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=RagManagedDbConfig(tier=Enterprise()), @@ -575,6 +585,18 @@ basic=GapicRagManagedDbConfig.Basic() ), ) +TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + scaled=GapicRagManagedDbConfig.Scaled() + ), +) +TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED = GapicRagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=GapicRagManagedDbConfig( + unprovisioned=GapicRagManagedDbConfig.Unprovisioned() + ), +) TEST_GAPIC_RAG_ENGINE_CONFIG_ENTERPRISE = GapicRagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=GapicRagManagedDbConfig( diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index da50f63145..94f6c35bf9 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -206,6 +206,113 @@ def list_rag_corpora_pager_mock(): yield list_rag_corpora_pager_mock +@pytest.fixture() +def update_rag_engine_config_basic_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_basic_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC + ) + update_rag_engine_config_basic_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_basic_mock + + +@pytest.fixture() +def update_rag_engine_config_scaled_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_scaled_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED + ) + update_rag_engine_config_scaled_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_scaled_mock + + +@pytest.fixture() +def update_rag_engine_config_unprovisioned_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_unprovisioned_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED + ) + update_rag_engine_config_unprovisioned_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield 
update_rag_engine_config_unprovisioned_mock + + +@pytest.fixture() +def update_rag_engine_config_mock_exception(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_mock_exception: + update_rag_engine_config_mock_exception.side_effect = Exception + yield update_rag_engine_config_mock_exception + + +@pytest.fixture() +def get_rag_engine_basic_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_basic_config_mock: + get_rag_engine_basic_config_mock.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC + ) + yield get_rag_engine_basic_config_mock + + +@pytest.fixture() +def get_rag_engine_scaled_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_scaled_config_mock: + get_rag_engine_scaled_config_mock.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED + ) + yield get_rag_engine_scaled_config_mock + + +@pytest.fixture() +def get_rag_engine_unprovisioned_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_unprovisioned_config_mock: + get_rag_engine_unprovisioned_config_mock.return_value = ( + test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED + ) + yield get_rag_engine_unprovisioned_config_mock + + +@pytest.fixture() +def get_rag_engine_config_mock_exception(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_config_mock_exception: + get_rag_engine_config_mock_exception.side_effect = Exception + yield get_rag_engine_config_mock_exception + + class MockResponse: def __init__(self, json_data, status_code): self.json_data = json_data @@ -355,6 +462,13 @@ def import_files_request_eq(returned_request, expected_request): ) +def rag_engine_config_eq(returned_config, expected_config): + assert returned_config.name == expected_config.name + assert returned_config.rag_managed_db_config.__eq__( + expected_config.rag_managed_db_config + ) + + @pytest.mark.usefixtures("google_auth_mock") class TestRagDataManagement: def setup_method(self): @@ -1084,3 +1198,95 @@ def test_set_embedding_model_config_wrong_endpoint_format_error(self): test_rag_constants.TEST_GAPIC_RAG_CORPUS, ) e.match("endpoint must be of the format ") + + def test_update_rag_engine_config_success( + self, update_rag_engine_config_basic_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC, + ) + assert update_rag_engine_config_basic_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC, + ) + + def test_update_rag_engine_config_scaled_success( + self, update_rag_engine_config_scaled_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED, + ) + assert update_rag_engine_config_scaled_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED, + ) + + def test_update_rag_engine_config_unprovisioned_success( + self, update_rag_engine_config_unprovisioned_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED, + ) + assert update_rag_engine_config_unprovisioned_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + 
test_rag_constants.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED, + ) + + @pytest.mark.usefixtures("update_rag_engine_config_mock_exception") + def test_update_rag_engine_config_failure(self): + with pytest.raises(RuntimeError) as e: + rag.update_rag_engine_config( + rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED, + ) + e.match("Failed in RagEngineConfig update due to") + + @pytest.mark.usefixtures("update_rag_engine_config_basic_mock") + def test_update_rag_engine_config_bad_input( + self, update_rag_engine_config_basic_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants.TEST_DEFAULT_RAG_ENGINE_CONFIG, + ) + assert update_rag_engine_config_basic_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC, + ) + + @pytest.mark.usefixtures("get_rag_engine_basic_config_mock") + def test_get_rag_engine_config_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC + ) + + @pytest.mark.usefixtures("get_rag_engine_scaled_config_mock") + def test_get_rag_engine_config_scaled_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED + ) + + @pytest.mark.usefixtures("get_rag_engine_unprovisioned_config_mock") + def test_get_rag_engine_config_unprovisioned_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, test_rag_constants.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED + ) + + @pytest.mark.usefixtures("get_rag_engine_config_mock_exception") + def test_get_rag_engine_config_failure(self): + with pytest.raises(RuntimeError) as e: + rag.get_rag_engine_config( + name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + e.match("Failed in getting the RagEngineConfig due to") diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index cdb4c5823e..cbd6bdd30c 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -475,6 +475,40 @@ def update_rag_engine_config_enterprise_mock(): yield update_rag_engine_config_enterprise_mock +@pytest.fixture() +def update_rag_engine_config_scaled_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_scaled_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED + ) + update_rag_engine_config_scaled_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_scaled_mock + + +@pytest.fixture() +def update_rag_engine_config_unprovisioned_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_unprovisioned_mock: + update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation) + update_rag_engine_config_lro_mock.done.return_value = True + update_rag_engine_config_lro_mock.result.return_value = ( + 
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED + ) + update_rag_engine_config_unprovisioned_mock.return_value = ( + update_rag_engine_config_lro_mock + ) + yield update_rag_engine_config_unprovisioned_mock + + @pytest.fixture() def update_rag_engine_config_mock_exception(): with mock.patch.object( @@ -497,6 +531,30 @@ def get_rag_engine_basic_config_mock(): yield get_rag_engine_basic_config_mock +@pytest.fixture() +def get_rag_engine_scaled_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_scaled_config_mock: + get_rag_engine_scaled_config_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED + ) + yield get_rag_engine_scaled_config_mock + + +@pytest.fixture() +def get_rag_engine_unprovisioned_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "get_rag_engine_config", + ) as get_rag_engine_unprovisioned_config_mock: + get_rag_engine_unprovisioned_config_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED + ) + yield get_rag_engine_unprovisioned_config_mock + + @pytest.fixture() def get_rag_engine_enterprise_config_mock(): with mock.patch.object( @@ -1642,6 +1700,42 @@ def test_update_rag_engine_config_success( test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_BASIC, ) + def test_update_rag_engine_config_enterprise_success( + self, update_rag_engine_config_enterprise_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE, + ) + assert update_rag_engine_config_enterprise_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE, + ) + + def test_update_rag_engine_config_scaled_success( + self, update_rag_engine_config_scaled_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SCALED, + ) + assert update_rag_engine_config_scaled_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SCALED, + ) + + def test_update_rag_engine_config_unprovisioned_success( + self, update_rag_engine_config_unprovisioned_mock + ): + rag_config = rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED, + ) + assert update_rag_engine_config_unprovisioned_mock.call_count == 1 + rag_engine_config_eq( + rag_config, + test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED, + ) + @pytest.mark.usefixtures("update_rag_engine_config_mock_exception") def test_update_rag_engine_config_failure(self): with pytest.raises(RuntimeError) as e: @@ -1650,17 +1744,17 @@ def test_update_rag_engine_config_failure(self): ) e.match("Failed in RagEngineConfig update due to") - @pytest.mark.usefixtures("update_rag_engine_config_enterprise_mock") + @pytest.mark.usefixtures("update_rag_engine_config_basic_mock") def test_update_rag_engine_config_bad_input( - self, update_rag_engine_config_enterprise_mock + self, update_rag_engine_config_basic_mock ): rag_config = rag.update_rag_engine_config( rag_engine_config=test_rag_constants_preview.TEST_DEFAULT_RAG_ENGINE_CONFIG, ) - assert update_rag_engine_config_enterprise_mock.call_count == 1 + assert update_rag_engine_config_basic_mock.call_count == 1 rag_engine_config_eq( rag_config, - test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE, + 
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_BASIC, ) @pytest.mark.usefixtures("get_rag_engine_basic_config_mock") @@ -1681,6 +1775,24 @@ def test_get_rag_engine_config_enterprise_success(self): rag_config, test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE ) + @pytest.mark.usefixtures("get_rag_engine_scaled_config_mock") + def test_get_rag_engine_config_scaled_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SCALED + ) + + @pytest.mark.usefixtures("get_rag_engine_unprovisioned_config_mock") + def test_get_rag_engine_config_unprovisioned_success(self): + rag_config = rag.get_rag_engine_config( + name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + ) + rag_engine_config_eq( + rag_config, test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED + ) + @pytest.mark.usefixtures("get_rag_engine_config_mock_exception") def test_get_rag_engine_config_failure(self): with pytest.raises(RuntimeError) as e: diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index f2f19570b5..9140c49fe4 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -64,11 +64,13 @@ RagVectorDbConfig, RankService, Ranking, + Scaled, SharePointSource, SharePointSources, SlackChannel, SlackChannelsSource, TransformationConfig, + Unprovisioned, VertexAiSearchConfig, VertexFeatureStore, VertexPredictionEndpoint, @@ -106,11 +108,13 @@ "Ranking", "RankService", "Retrieval", + "Scaled", "SharePointSource", "SharePointSources", "SlackChannel", "SlackChannelsSource", "TransformationConfig", + "Unprovisioned", "VertexAiSearchConfig", "VertexFeatureStore", "VertexPredictionEndpoint", diff --git a/vertexai/preview/rag/utils/_gapic_utils.py b/vertexai/preview/rag/utils/_gapic_utils.py index 5e99e5a59d..3501142c68 100644 --- a/vertexai/preview/rag/utils/_gapic_utils.py +++ b/vertexai/preview/rag/utils/_gapic_utils.py @@ -62,9 +62,11 @@ RagVectorDbConfig, Basic, Enterprise, + Scaled, SharePointSources, SlackChannelsSource, TransformationConfig, + Unprovisioned, VertexAiSearchConfig, VertexFeatureStore, VertexPredictionEndpoint, @@ -993,6 +995,10 @@ def convert_gapic_to_rag_engine_config( rag_managed_db_config.tier = Enterprise() elif response.rag_managed_db_config.__contains__("basic"): rag_managed_db_config.tier = Basic() + elif response.rag_managed_db_config.__contains__("unprovisioned"): + rag_managed_db_config.tier = Unprovisioned() + elif response.rag_managed_db_config.__contains__("scaled"): + rag_managed_db_config.tier = Scaled() else: raise ValueError("At least one of rag_managed_db_config must be set.") return RagEngineConfig( @@ -1011,13 +1017,19 @@ def convert_rag_engine_config_to_gapic( or rag_engine_config.rag_managed_db_config.tier is None ): rag_managed_db_config = GapicRagManagedDbConfig( - enterprise=GapicRagManagedDbConfig.Enterprise() + basic=GapicRagManagedDbConfig.Basic() ) else: if isinstance(rag_engine_config.rag_managed_db_config.tier, Enterprise): rag_managed_db_config.enterprise = GapicRagManagedDbConfig.Enterprise() elif isinstance(rag_engine_config.rag_managed_db_config.tier, Basic): rag_managed_db_config.basic = GapicRagManagedDbConfig.Basic() + elif isinstance(rag_engine_config.rag_managed_db_config.tier, Unprovisioned): + rag_managed_db_config.unprovisioned = ( + GapicRagManagedDbConfig.Unprovisioned() + ) + elif 
isinstance(rag_engine_config.rag_managed_db_config.tier, Scaled): + rag_managed_db_config.scaled = GapicRagManagedDbConfig.Scaled() return GapicRagEngineConfig( name=rag_engine_config.name, rag_managed_db_config=rag_managed_db_config, diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index bf8e8bffff..459c924e13 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -569,7 +569,16 @@ class Enterprise: autoscaling functionality. It is suitable for customers with large amounts of data or performance sensitive workloads. - NOTE: This is the default tier if not explicitly chosen. + NOTE: This is deprecated. Use Scaled tier instead. + """ + + +@dataclasses.dataclass +class Scaled: + """Scaled tier offers production grade performance along with + + autoscaling functionality. It is suitable for customers with large + amounts of data or performance sensitive workloads. """ @@ -581,6 +590,19 @@ class Basic: * Small data size. * Latency insensitive workload. * Only using RAG Engine with external vector DBs. + + NOTE: This is the default tier if not explicitly chosen. + """ + + +@dataclasses.dataclass +class Unprovisioned: + """Disables the RAG Engine service and deletes all your data held within + this service. This will halt the billing of the service. + + NOTE: Once deleted the data cannot be recovered. To start using + RAG Engine again, you will need to update the tier by calling the + UpdateRagEngineConfig API. """ @@ -591,10 +613,10 @@ class RagManagedDbConfig: The config of the RagManagedDb used by RagEngine. Attributes: - tier: The tier of the RagManagedDb. The default tier is Enterprise. + tier: The tier of the RagManagedDb. The default tier is Basic. """ - tier: Optional[Union[Enterprise, Basic]] = None + tier: Optional[Union[Enterprise, Basic, Scaled, Unprovisioned]] = None @dataclasses.dataclass @@ -605,7 +627,7 @@ class RagEngineConfig: name: Generated resource name for singleton resource. Format: ``projects/{project}/locations/{location}/ragEngineConfig`` rag_managed_db_config: The config of the RagManagedDb used by RagEngine. - The default tier is Enterprise. + The default tier is Basic. 
""" name: str diff --git a/vertexai/rag/__init__.py b/vertexai/rag/__init__.py index 3f30abbda3..9fc802e0ec 100644 --- a/vertexai/rag/__init__.py +++ b/vertexai/rag/__init__.py @@ -16,29 +16,30 @@ # from vertexai.rag.rag_data import ( + add_inline_citations_and_references, create_corpus, - update_corpus, - list_corpora, - get_corpus, delete_corpus, - upload_file, + delete_file, + get_corpus, + get_file, + get_rag_engine_config, import_files, import_files_async, - get_file, + list_corpora, list_files, - delete_file, - add_inline_citations_and_references, + update_corpus, + update_rag_engine_config, + upload_file, ) - from vertexai.rag.rag_retrieval import ( retrieval_query, ) - from vertexai.rag.rag_store import ( Retrieval, VertexRagStore, ) from vertexai.rag.utils.resources import ( + Basic, ChunkingConfig, Filter, JiraQuery, @@ -47,21 +48,25 @@ LlmParserConfig, LlmRanker, Pinecone, + RagCitedGenerationResponse, RagCorpus, RagEmbeddingModelConfig, + RagEngineConfig, RagFile, - RagCitedGenerationResponse, RagManagedDb, + RagManagedDbConfig, RagResource, RagRetrievalConfig, RagVectorDbConfig, - Ranking, RankService, + Ranking, + Scaled, SharePointSource, SharePointSources, SlackChannel, SlackChannelsSource, TransformationConfig, + Unprovisioned, VertexAiSearchConfig, VertexPredictionEndpoint, VertexVectorSearch, @@ -69,6 +74,7 @@ __all__ = ( + "Basic", "ChunkingConfig", "Filter", "JiraQuery", @@ -79,20 +85,24 @@ "Pinecone", "RagCorpus", "RagEmbeddingModelConfig", + "RagEngineConfig", "RagFile", "RagCitedGenerationResponse", "RagManagedDb", + "RagManagedDbConfig", "RagResource", "RagRetrievalConfig", "RagVectorDbConfig", "Ranking", "RankService", "Retrieval", + "Scaled", "SharePointSource", "SharePointSources", "SlackChannel", "SlackChannelsSource", "TransformationConfig", + "Unprovisioned", "VertexAiSearchConfig", "VertexRagStore", "VertexPredictionEndpoint", @@ -101,6 +111,7 @@ "delete_corpus", "delete_file", "get_corpus", + "get_rag_engine_config", "get_file", "import_files", "import_files_async", @@ -109,5 +120,6 @@ "retrieval_query", "upload_file", "update_corpus", + "update_rag_engine_config", "add_inline_citations_and_references", ) diff --git a/vertexai/rag/rag_data.py b/vertexai/rag/rag_data.py index 0c5fc3975e..87a7f1bbc5 100644 --- a/vertexai/rag/rag_data.py +++ b/vertexai/rag/rag_data.py @@ -28,12 +28,14 @@ DeleteRagCorpusRequest, DeleteRagFileRequest, GetRagCorpusRequest, + GetRagEngineConfigRequest, GetRagFileRequest, ImportRagFilesResponse, ListRagCorporaRequest, ListRagFilesRequest, RagCorpus as GapicRagCorpus, UpdateRagCorpusRequest, + UpdateRagEngineConfigRequest, ) from google.cloud.aiplatform_v1.services.vertex_rag_data_service.pagers import ( ListRagCorporaPager, @@ -53,6 +55,7 @@ LlmParserConfig, RagCitedGenerationResponse, RagCorpus, + RagEngineConfig, RagFile, RagVectorDbConfig, SharePointSources, @@ -897,6 +900,76 @@ def delete_file(name: str, corpus_name: Optional[str] = None) -> None: return None +def update_rag_engine_config( + rag_engine_config: RagEngineConfig, +) -> RagEngineConfig: + """Update RagEngineConfig. + + Example usage: + ``` + import vertexai + from vertexai import rag + vertexai.init(project="my-project") + rag_engine_config = rag.RagEngineConfig( + rag_managed_db_config=rag.RagManagedDbConfig( + rag_managed_db=rag.RagManagedDb( + db_basic_tier=rag.Basic(), + ), + ) + ), + ) + rag.update_rag_engine_config(rag_engine_config=rag_engine_config) + ``` + + Args: + rag_engine_config: The RagEngineConfig to update. 
+ + Raises: + RuntimeError: Failed in RagEngineConfig update due to exception. + """ + gapic_rag_engine_config = _gapic_utils.convert_rag_engine_config_to_gapic( + rag_engine_config + ) + request = UpdateRagEngineConfigRequest( + rag_engine_config=gapic_rag_engine_config, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.update_rag_engine_config(request=request) + except Exception as e: + raise RuntimeError("Failed in RagEngineConfig update due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_engine_config(response.result(timeout=600)) + + +def get_rag_engine_config(name: str) -> RagEngineConfig: + """Get an existing RagEngineConfig. + + Example usage: + ``` + import vertexai + from vertexai import rag + vertexai.init(project="my-project") + rag_engine_config = rag.get_rag_engine_config( + name="projects/my-project/locations/us-central1/ragEngineConfig" + ) + ``` + Args: + name: The RagEngineConfig resource name pattern of the singleton resource. + + Returns: + RagEngineConfig. + Raises: + RuntimeError: Failed in getting the RagEngineConfig. + """ + request = GetRagEngineConfigRequest(name=name) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.get_rag_engine_config(request=request) + except Exception as e: + raise RuntimeError("Failed in getting the RagEngineConfig due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_engine_config(response) + + def add_inline_citations_and_references( original_text_str, grounding_supports, grounding_chunks ) -> RagCitedGenerationResponse: diff --git a/vertexai/rag/utils/_gapic_utils.py b/vertexai/rag/utils/_gapic_utils.py index 997131f7fc..fa8ce9bdfa 100644 --- a/vertexai/rag/utils/_gapic_utils.py +++ b/vertexai/rag/utils/_gapic_utils.py @@ -23,11 +23,13 @@ GoogleDriveSource, ImportRagFilesConfig, ImportRagFilesRequest, + RagEngineConfig as GapicRagEngineConfig, RagFileChunkingConfig, RagFileParsingConfig, RagFileTransformationConfig, RagCorpus as GapicRagCorpus, RagFile as GapicRagFile, + RagManagedDbConfig as GapicRagManagedDbConfig, SharePointSources as GapicSharePointSources, SlackSource as GapicSlackSource, JiraSource as GapicJiraSource, @@ -41,22 +43,27 @@ VertexRagClientWithOverride, ) from vertexai.rag.utils.resources import ( + Basic, + JiraSource, LayoutParserConfig, LlmParserConfig, Pinecone, + RagCitedGenerationResponse, RagCorpus, RagEmbeddingModelConfig, + RagEngineConfig, RagFile, RagManagedDb, + RagManagedDbConfig, RagVectorDbConfig, + Scaled, SharePointSources, SlackChannelsSource, TransformationConfig, - JiraSource, + Unprovisioned, VertexAiSearchConfig, VertexVectorSearch, VertexPredictionEndpoint, - RagCitedGenerationResponse, ) @@ -707,3 +714,51 @@ def set_vertex_ai_search_config( raise ValueError( "serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`" ) + + +def convert_gapic_to_rag_engine_config( + response: GapicRagEngineConfig, +) -> RagEngineConfig: + """Converts a GapicRagEngineConfig to a RagEngineConfig.""" + rag_managed_db_config = RagManagedDbConfig() + # If future fields are added with similar names, beware that __contains__ + # may match them. 
+ if response.rag_managed_db_config.__contains__("basic"): + rag_managed_db_config.tier = Basic() + elif response.rag_managed_db_config.__contains__("unprovisioned"): + rag_managed_db_config.tier = Unprovisioned() + elif response.rag_managed_db_config.__contains__("scaled"): + rag_managed_db_config.tier = Scaled() + else: + raise ValueError("At least one of rag_managed_db_config must be set.") + return RagEngineConfig( + name=response.name, + rag_managed_db_config=rag_managed_db_config, + ) + + +def convert_rag_engine_config_to_gapic( + rag_engine_config: RagEngineConfig, +) -> GapicRagEngineConfig: + """Converts a RagEngineConfig to a GapicRagEngineConfig.""" + rag_managed_db_config = GapicRagManagedDbConfig() + if ( + rag_engine_config.rag_managed_db_config is None + or rag_engine_config.rag_managed_db_config.tier is None + ): + rag_managed_db_config = GapicRagManagedDbConfig( + basic=GapicRagManagedDbConfig.Basic() + ) + else: + if isinstance(rag_engine_config.rag_managed_db_config.tier, Basic): + rag_managed_db_config.basic = GapicRagManagedDbConfig.Basic() + elif isinstance(rag_engine_config.rag_managed_db_config.tier, Unprovisioned): + rag_managed_db_config.unprovisioned = ( + GapicRagManagedDbConfig.Unprovisioned() + ) + elif isinstance(rag_engine_config.rag_managed_db_config.tier, Scaled): + rag_managed_db_config.scaled = GapicRagManagedDbConfig.Scaled() + return GapicRagEngineConfig( + name=rag_engine_config.name, + rag_managed_db_config=rag_managed_db_config, + ) diff --git a/vertexai/rag/utils/resources.py b/vertexai/rag/utils/resources.py index 6f65583f49..248ca7d016 100644 --- a/vertexai/rag/utils/resources.py +++ b/vertexai/rag/utils/resources.py @@ -486,3 +486,64 @@ class RagCitedGenerationResponse: cited_text: str final_bibliography: str + + +@dataclasses.dataclass +class Scaled: + """Scaled tier offers production grade performance along with + + autoscaling functionality. It is suitable for customers with large + amounts of data or performance sensitive workloads. + """ + + +@dataclasses.dataclass +class Basic: + """Basic tier is a cost-effective and low compute tier suitable for the following cases: + + * Experimenting with RagManagedDb. + * Small data size. + * Latency insensitive workload. + * Only using RAG Engine with external vector DBs. + + NOTE: This is the default tier if not explicitly chosen. + """ + + +@dataclasses.dataclass +class Unprovisioned: + """Disables the RAG Engine service and deletes all your data held within + this service. This will halt the billing of the service. + + NOTE: Once deleted the data cannot be recovered. To start using + RAG Engine again, you will need to update the tier by calling the + UpdateRagEngineConfig API. + """ + + +@dataclasses.dataclass +class RagManagedDbConfig: + """RagManagedDbConfig. + + The config of the RagManagedDb used by RagEngine. + + Attributes: + tier: The tier of the RagManagedDb. The default tier is Basic. + """ + + tier: Optional[Union[Basic, Scaled, Unprovisioned]] = None + + +@dataclasses.dataclass +class RagEngineConfig: + """RagEngineConfig. + + Attributes: + name: Generated resource name for singleton resource. Format: + ``projects/{project}/locations/{location}/ragEngineConfig`` + rag_managed_db_config: The config of the RagManagedDb used by RagEngine. + The default tier is Basic. 
+ """ + + name: str + rag_managed_db_config: Optional[RagManagedDbConfig] = None From c43de0ab174b22adf95f1b97c701ec2c4b4581fb Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Wed, 18 Jun 2025 13:34:20 -0700 Subject: [PATCH 22/24] feat: GenAI SDK client - add `show` method for `EvaluationResult` and `EvaluationDataset` classes in IPython environment PiperOrigin-RevId: 773042552 --- vertexai/_genai/__init__.py | 2 +- vertexai/_genai/_evals_metric_handlers.py | 68 +++- vertexai/_genai/_evals_visualization.py | 398 ++++++++++++++++++++++ vertexai/_genai/client.py | 2 +- vertexai/_genai/types.py | 68 +++- 5 files changed, 523 insertions(+), 15 deletions(-) create mode 100644 vertexai/_genai/_evals_visualization.py diff --git a/vertexai/_genai/__init__.py b/vertexai/_genai/__init__.py index 0f8415ea49..2c5aa504f5 100644 --- a/vertexai/_genai/__init__.py +++ b/vertexai/_genai/__init__.py @@ -29,7 +29,7 @@ def __getattr__(name): _evals = importlib.import_module(".evals", __package__) except ImportError as e: raise ImportError( - "The 'evals' module requires 'pandas' and 'tqdm'. " + "The 'evals' module requires additional dependencies. " "Please install them using pip install " "google-cloud-aiplatform[evaluation]" ) from e diff --git a/vertexai/_genai/_evals_metric_handlers.py b/vertexai/_genai/_evals_metric_handlers.py index 209426451b..693e0ef939 100644 --- a/vertexai/_genai/_evals_metric_handlers.py +++ b/vertexai/_genai/_evals_metric_handlers.py @@ -743,6 +743,55 @@ def get_handler_for_metric( raise ValueError(f"Unsupported metric: {metric.name}") +def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]: + """Calculates win/tie rates for comparison results.""" + if not eval_result.eval_case_results: + return {} + max_models = max( + ( + len(case.response_candidate_results) + for case in eval_result.eval_case_results + if case.response_candidate_results + ), + default=0, + ) + if max_models == 0: + return {} + stats = collections.defaultdict( + lambda: {"wins": [0] * max_models, "ties": 0, "valid_comparisons": 0} + ) + for case in eval_result.eval_case_results: + if not case.response_candidate_results: + continue + scores_by_metric = collections.defaultdict(list) + for idx, candidate in enumerate(case.response_candidate_results): + for name, res in ( + candidate.metric_results.items() if candidate.metric_results else {} + ): + if res.score is not None: + scores_by_metric[name].append({"score": res.score, "cand_idx": idx}) + for name, scores in scores_by_metric.items(): + if not scores: + continue + stats[name]["valid_comparisons"] += 1 + max_score = max(s["score"] for s in scores) + winners = [s["cand_idx"] for s in scores if s["score"] == max_score] + if len(winners) == 1: + stats[name]["wins"][winners[0]] += 1 + else: + stats[name]["ties"] += 1 + win_rates = {} + for name, metric_stats in stats.items(): + if metric_stats["valid_comparisons"] > 0: + win_rates[name] = { + "win_rates": [ + w / metric_stats["valid_comparisons"] for w in metric_stats["wins"] + ], + "tie_rate": metric_stats["ties"] / metric_stats["valid_comparisons"], + } + return win_rates + + def _aggregate_metric_results( metric_handlers: list[MetricHandler], eval_case_results: list[types.EvalCaseResult], @@ -1001,10 +1050,6 @@ def compute_metrics_and_aggregate( ) final_eval_case_results.append(eval_case_result) - aggregated_metric_results = _aggregate_metric_results( - metric_handlers, final_eval_case_results - ) - if submission_errors: logger.warning("Encountered %d submission errors.", 
len(submission_errors)) logger.warning("Submission errors: %s", submission_errors) @@ -1012,7 +1057,20 @@ def compute_metrics_and_aggregate( logger.warning("Encountered %d execution errors.", len(execution_errors)) logger.warning("Execution errors: %s", execution_errors) - return types.EvaluationResult( + aggregated_metric_results = _aggregate_metric_results( + metric_handlers, final_eval_case_results + ) + eval_result = types.EvaluationResult( eval_case_results=final_eval_case_results, summary_metrics=aggregated_metric_results, ) + if evaluation_run_config.num_response_candidates > 1: + try: + eval_result.win_rates = calculate_win_rates(eval_result) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error calculating win rates: %s", + e, + exc_info=True, + ) + return eval_result diff --git a/vertexai/_genai/_evals_visualization.py b/vertexai/_genai/_evals_visualization.py new file mode 100644 index 0000000000..8727a2cbb2 --- /dev/null +++ b/vertexai/_genai/_evals_visualization.py @@ -0,0 +1,398 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Visualization utilities for GenAI Evaluation SDK.""" + +import json +import logging +from typing import Optional + +from pydantic import errors +import pandas as pd + +from . import types + +logger = logging.getLogger(__name__) + + +def _is_ipython_env() -> bool: + """Checks if the code is running in an IPython environment.""" + try: + from IPython import get_ipython + + return get_ipython() is not None + except ImportError: + return False + + +def _preprocess_df_for_json(df: Optional[pd.DataFrame]) -> Optional[pd.DataFrame]: + """Prepares a DataFrame for JSON serialization by converting complex objects to strings.""" + if df is None: + return None + df_copy = df.copy() + + for col in df_copy.columns: + if ( + df_copy[col].dtype == "object" + or df_copy[col].apply(lambda x: isinstance(x, (dict, list))).any() + ): + + def stringify_cell(cell): + if pd.isna(cell): + return None + if isinstance(cell, (dict, list)): + try: + return json.dumps(cell, ensure_ascii=False) + except TypeError: + return str(cell) + elif not isinstance(cell, (str, int, float, bool)): + return str(cell) + return cell + + df_copy[col] = df_copy[col].apply(stringify_cell) + return df_copy + + +def _get_evaluation_html(eval_result_json: str) -> str: + """Returns a self-contained HTML for single evaluation visualization.""" + return f""" + + + + + Evaluation Report + + + + + +
+

Evaluation Report

+
+
+
+ + + +""" + + +def _get_comparison_html(eval_result_json: str) -> str: + """Returns a self-contained HTML for a side-by-side eval comparison.""" + return f""" + + + + + Eval Comparison Report + + + + + +
+

Eval Comparison Report

+
+
+
+ + + +""" + + +def _get_inference_html(dataframe_json: str) -> str: + """Returns a self-contained HTML for displaying inference results.""" + return f""" + + + + + Inference Results + + + + + +
+

Inference Results

+
+
+ + + +""" + + +def display_evaluation_result( + eval_result_obj: types.EvaluationResult, + candidate_names: Optional[list[str]] = None, +) -> None: + """Displays evaluation result in an IPython environment.""" + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + from IPython import display + + try: + result_dump = eval_result_obj.model_dump( + mode="json", exclude_none=True, exclude={"evaluation_dataset"} + ) + except errors.PydanticSerializationError as e: + logger.error( + "Serialization Error: %s\nCould not display the evaluation " + "result due to a data serialization issue. Please check the " + "content of the EvaluationResult object.", + e, + ) + return + except Exception as e: + logger.error("Failed to serialize EvaluationResult: %s", e, exc_info=True) + raise + + input_dataset_list = eval_result_obj.evaluation_dataset + is_comparison = input_dataset_list and len(input_dataset_list) > 1 + + metadata_payload = result_dump.get("metadata", {}) + metadata_payload["candidate_names"] = candidate_names or metadata_payload.get( + "candidate_names" + ) + + if is_comparison: + if ( + input_dataset_list + and input_dataset_list[0] + and input_dataset_list[0].eval_dataset_df is not None + ): + metadata_payload["dataset"] = _preprocess_df_for_json( + input_dataset_list[0].eval_dataset_df + ).to_dict(orient="records") + + if "eval_case_results" in result_dump: + for case_res in result_dump["eval_case_results"]: + for resp_idx, cand_res in enumerate( + case_res.get("response_candidate_results", []) + ): + if ( + resp_idx < len(input_dataset_list) + and input_dataset_list[resp_idx].eval_dataset_df is not None + ): + df = _preprocess_df_for_json( + input_dataset_list[resp_idx].eval_dataset_df + ) + case_idx = case_res.get("eval_case_index") + if ( + df is not None + and case_idx is not None + and case_idx < len(df) + ): + cand_res["response_text"] = df.iloc[case_idx].get( + "response" + ) + + win_rates = eval_result_obj.win_rates if eval_result_obj.win_rates else {} + if "summary_metrics" in result_dump: + for summary in result_dump["summary_metrics"]: + if summary.get("metric_name") in win_rates: + summary.update(win_rates[summary["metric_name"]]) + + result_dump["metadata"] = metadata_payload + html_content = _get_comparison_html(json.dumps(result_dump)) + else: + single_dataset = input_dataset_list[0] if input_dataset_list else None + + if ( + single_dataset is not None + and isinstance(single_dataset, types.EvaluationDataset) + and single_dataset.eval_dataset_df is not None + ): + processed_df = _preprocess_df_for_json(single_dataset.eval_dataset_df) + metadata_payload["dataset"] = processed_df.to_dict(orient="records") + if "eval_case_results" in result_dump and processed_df is not None: + for case_res in result_dump["eval_case_results"]: + case_idx = case_res.get("eval_case_index") + if ( + case_idx is not None + and case_idx < len(processed_df) + and case_res.get("response_candidate_results") + ): + case_res["response_candidate_results"][0][ + "response_text" + ] = processed_df.iloc[case_idx].get("response") + + result_dump["metadata"] = metadata_payload + html_content = _get_evaluation_html(json.dumps(result_dump)) + + display.display(display.HTML(html_content)) + + +def display_evaluation_dataset(eval_dataset_obj: types.EvaluationDataset) -> None: + """Displays an evaluation dataset in an IPython environment.""" + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + 
from IPython import display + + if ( + eval_dataset_obj.eval_dataset_df is None + or eval_dataset_obj.eval_dataset_df.empty + ): + logger.warning("No inference data to display.") + return + + processed_df = _preprocess_df_for_json(eval_dataset_obj.eval_dataset_df) + dataframe_json_string = json.dumps(processed_df.to_json(orient="records")) + html_content = _get_inference_html(dataframe_json_string) + display.display(display.HTML(html_content)) diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py index a0cb0052df..dc6f22329e 100644 --- a/vertexai/_genai/client.py +++ b/vertexai/_genai/client.py @@ -118,7 +118,7 @@ def evals(self): self._evals = importlib.import_module(".evals", __package__) except ImportError as e: raise ImportError( - "The 'evals' module requires 'pandas' and 'tqdm'. " + "The 'evals' module requires additional dependencies. " "Please install them using pip install " "google-cloud-aiplatform[evaluation]" ) from e diff --git a/vertexai/_genai/types.py b/vertexai/_genai/types.py index e87965e339..3d224c982a 100644 --- a/vertexai/_genai/types.py +++ b/vertexai/_genai/types.py @@ -21,7 +21,17 @@ import logging import re import typing -from typing import Any, Callable, ClassVar, Literal, Optional, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + ClassVar, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, +) from google.genai import _common from google.genai import types as genai_types from pydantic import ( @@ -2063,6 +2073,12 @@ class EvaluationDataset(_common.BaseModel): description="""The BigQuery source for the evaluation dataset.""", ) + def show(self) -> None: + """Shows the evaluation dataset.""" + from . import _evals_visualization + + _evals_visualization.display_evaluation_dataset(self) + class EvaluationDatasetDict(TypedDict, total=False): """The dataset used for evaluation.""" @@ -2887,6 +2903,31 @@ class EvalRunInferenceConfigDict(TypedDict, total=False): EvalRunInferenceConfigOrDict = Union[EvalRunInferenceConfig, EvalRunInferenceConfigDict] +class WinRateStats(_common.BaseModel): + """Statistics for win rates for a single metric.""" + + win_rates: Optional[list[float]] = Field( + default=None, + description="""Win rates for the metric, one for each candidate.""", + ) + tie_rate: Optional[float] = Field( + default=None, description="""Tie rate for the metric.""" + ) + + +class WinRateStatsDict(TypedDict, total=False): + """Statistics for win rates for a single metric.""" + + win_rates: Optional[list[float]] + """Win rates for the metric, one for each candidate.""" + + tie_rate: Optional[float] + """Tie rate for the metric.""" + + +WinRateStatsOrDict = Union[WinRateStats, WinRateStatsDict] + + class EvalCaseMetricResult(_common.BaseModel): """Evaluation result for a single evaluation case for a single metric.""" @@ -3009,10 +3050,6 @@ class AggregatedMetricResult(_common.BaseModel): stdev_score: Optional[float] = Field( default=None, description="""Standard deviation of the metric.""" ) - win_rate: Optional[dict[str, float]] = Field( - default=None, - description="""A dictionary of win rates for each response.""", - ) # Allow extra fields to support custom aggregation stats. 
model_config = ConfigDict(extra="allow") @@ -3039,9 +3076,6 @@ class AggregatedMetricResultDict(TypedDict, total=False): stdev_score: Optional[float] """Standard deviation of the metric.""" - win_rate: Optional[dict[str, float]] - """A dictionary of win rates for each response.""" - AggregatedMetricResultOrDict = Union[AggregatedMetricResult, AggregatedMetricResultDict] @@ -3097,6 +3131,10 @@ class EvaluationResult(_common.BaseModel): default=None, description="""A list of summary-level evaluation results for each metric.""", ) + win_rates: Optional[dict[str, WinRateStats]] = Field( + default=None, + description="""A dictionary of win rates for each metric, only populated for multi-response evaluation runs.""", + ) evaluation_dataset: Optional[list[EvaluationDataset]] = Field( default=None, description="""The input evaluation dataset(s) for the evaluation run.""", @@ -3105,6 +3143,17 @@ class EvaluationResult(_common.BaseModel): default=None, description="""Metadata for the evaluation run.""" ) + def show(self, candidate_names: Optional[List[str]] = None) -> None: + """Shows the evaluation result. + + Args: + candidate_names: list of names for the evaluated candidates, used in + comparison reports. + """ + from . import _evals_visualization + + _evals_visualization.display_evaluation_result(self, candidate_names) + class EvaluationResultDict(TypedDict, total=False): """Result of an evaluation run for an evaluation dataset.""" @@ -3115,6 +3164,9 @@ class EvaluationResultDict(TypedDict, total=False): summary_metrics: Optional[list[AggregatedMetricResultDict]] """A list of summary-level evaluation results for each metric.""" + win_rates: Optional[dict[str, WinRateStatsDict]] + """A dictionary of win rates for each metric, only populated for multi-response evaluation runs.""" + evaluation_dataset: Optional[list[EvaluationDatasetDict]] """The input evaluation dataset(s) for the evaluation run.""" From 0a26a2080917c2c38f6c35fb22b4dc779d3d0f9a Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Wed, 18 Jun 2025 17:33:26 -0700 Subject: [PATCH 23/24] chore: remove prompt_optimizer property from GenAI client PiperOrigin-RevId: 773126369 --- .../vertexai/genai/test_prompt_optimizer.py | 55 ++++++++++--------- vertexai/_genai/client.py | 11 ---- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/tests/unit/vertexai/genai/test_prompt_optimizer.py b/tests/unit/vertexai/genai/test_prompt_optimizer.py index 73f82736d4..e08c4fff8a 100644 --- a/tests/unit/vertexai/genai/test_prompt_optimizer.py +++ b/tests/unit/vertexai/genai/test_prompt_optimizer.py @@ -28,8 +28,9 @@ from google.cloud.aiplatform.compat.types import ( job_state as gca_job_state_compat, ) -from google.cloud.aiplatform.utils import gcs_utils -from google.genai import client + +# from google.cloud.aiplatform.utils import gcs_utils +# from google.genai import client import pytest @@ -76,29 +77,29 @@ def setup_method(self): location=_TEST_LOCATION, ) - @pytest.mark.usefixtures("google_auth_mock") - def test_prompt_optimizer_client(self): - test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) - assert test_client is not None - assert test_client._api_client.vertexai - assert test_client._api_client.project == _TEST_PROJECT - assert test_client._api_client.location == _TEST_LOCATION + # @pytest.mark.usefixtures("google_auth_mock") + # def test_prompt_optimizer_client(self): + # test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + # assert test_client is not None + # assert 
test_client._api_client.vertexai + # assert test_client._api_client.project == _TEST_PROJECT + # assert test_client._api_client.location == _TEST_LOCATION - @mock.patch.object(client.Client, "_get_api_client") - @mock.patch.object( - gcs_utils.resource_manager_utils, "get_project_number", return_value=12345 - ) - def test_prompt_optimizer_optimize( - self, mock_get_project_number, mock_client, mock_create_custom_job - ): - """Test that prompt_optimizer.optimize method creates a custom job.""" - test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) - test_client.prompt_optimizer.optimize( - method="vapo", - config={ - "config_path": "gs://ssusie-vapo-sdk-test/config.json", - "wait_for_completion": False, - }, - ) - mock_create_custom_job.assert_called_once() - mock_get_project_number.assert_called_once() + # @mock.patch.object(client.Client, "_get_api_client") + # @mock.patch.object( + # gcs_utils.resource_manager_utils, "get_project_number", return_value=12345 + # ) + # def test_prompt_optimizer_optimize( + # self, mock_get_project_number, mock_client, mock_create_custom_job + # ): + # """Test that prompt_optimizer.optimize method creates a custom job.""" + # test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + # test_client.prompt_optimizer.optimize( + # method="vapo", + # config={ + # "config_path": "gs://ssusie-vapo-sdk-test/config.json", + # "wait_for_completion": False, + # }, + # ) + # mock_create_custom_job.assert_called_once() + # mock_get_project_number.assert_called_once() diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py index dc6f22329e..3d06a34064 100644 --- a/vertexai/_genai/client.py +++ b/vertexai/_genai/client.py @@ -124,17 +124,6 @@ def evals(self): ) from e return self._evals.Evals(self._api_client) - @property - @_common.experimental_warning( - "The Vertex SDK GenAI prompt optimizer module is experimental, " - "and may change in future versions." 
- ) - def prompt_optimizer(self): - self._prompt_optimizer = importlib.import_module( - ".prompt_optimizer", __package__ - ) - return self._prompt_optimizer.PromptOptimizer(self._api_client) - @property @_common.experimental_warning( "The Vertex SDK GenAI async client is experimental, " From 73873478057ffa47c329c4369fca60e9dea38260 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Fri, 20 Jun 2025 09:23:40 -0700 Subject: [PATCH 24/24] chore(main): release 1.98.0 (#5416) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- .release-please-manifest.json | 2 +- CHANGELOG.md | 35 +++++++++++++++++++ google/cloud/aiplatform/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1/gapic_version.py | 2 +- .../v1/schema/predict/params/gapic_version.py | 2 +- .../schema/predict/params_v1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../predict/prediction_v1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1/gapic_version.py | 2 +- .../schema/predict/instance/gapic_version.py | 2 +- .../predict/instance_v1beta1/gapic_version.py | 2 +- .../schema/predict/params/gapic_version.py | 2 +- .../predict/params_v1beta1/gapic_version.py | 2 +- .../predict/prediction/gapic_version.py | 2 +- .../prediction_v1beta1/gapic_version.py | 2 +- .../trainingjob/definition/gapic_version.py | 2 +- .../definition_v1beta1/gapic_version.py | 2 +- google/cloud/aiplatform/version.py | 2 +- google/cloud/aiplatform_v1/gapic_version.py | 2 +- .../cloud/aiplatform_v1beta1/gapic_version.py | 2 +- pypi/_vertex_ai_placeholder/version.py | 2 +- ...t_metadata_google.cloud.aiplatform.v1.json | 2 +- ...adata_google.cloud.aiplatform.v1beta1.json | 2 +- 25 files changed, 59 insertions(+), 24 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index ddfefcab9f..166b986ce3 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.97.0" + ".": "1.98.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index dca632818e..0e24845fd3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,40 @@ # Changelog +## [1.98.0](https://github.com/googleapis/python-aiplatform/compare/v1.97.0...v1.98.0) (2025-06-19) + + +### Features + +* Add DnsPeeringConfig in service_networking.proto ([c5bb99b](https://github.com/googleapis/python-aiplatform/commit/c5bb99b80dbbc76ababdba1228154717370eb5dd)) +* Add DnsPeeringConfig in service_networking.proto ([c5bb99b](https://github.com/googleapis/python-aiplatform/commit/c5bb99b80dbbc76ababdba1228154717370eb5dd)) +* Add EncryptionSpec field for RagCorpus CMEK feature to v1 ([9b48d24](https://github.com/googleapis/python-aiplatform/commit/9b48d24ab90c57d4a49b3adf22a79cffbe065351)) +* Add PSC interface config support for Custom Training Jobs ([267b53d](https://github.com/googleapis/python-aiplatform/commit/267b53d4a7db87cdf70181f76adb5c6980a2136a)) +* Add RagEngineConfig update/get APIs to v1 ([c5bb99b](https://github.com/googleapis/python-aiplatform/commit/c5bb99b80dbbc76ababdba1228154717370eb5dd)) +* Add Scaled tier for RagEngineConfig to v1beta, equivalent to Enterprise ([c5bb99b](https://github.com/googleapis/python-aiplatform/commit/c5bb99b80dbbc76ababdba1228154717370eb5dd)) +* Added `autoscaling_target_request_count_per_minute` to model deployment on Endpoint and Model classes 
([4df909c](https://github.com/googleapis/python-aiplatform/commit/4df909cdfddf071c4b87a7d1dabed56437d528a2))
+* Adding VAPO Prompt Optimizer (PO-data) to the genai SDK. ([701b8d4](https://github.com/googleapis/python-aiplatform/commit/701b8d40ba7b1265051a8b6a507e8a6b8e242a54))
+* Enable asia-south2 ([a1f4205](https://github.com/googleapis/python-aiplatform/commit/a1f420582908bc3d9a3201d36bf8d075758d4644))
+* Enable Vertex Multimodal Dataset as input to supervised fine-tuning. ([959d798](https://github.com/googleapis/python-aiplatform/commit/959d79869468c1fa66b7691eb8c4071a5af3eec4))
+* Export global quota configs in preview sdk ([7f964d5](https://github.com/googleapis/python-aiplatform/commit/7f964d5625bea84fada767efc32661db34473a80))
+* GenAI SDK client - add `show` method for `EvaluationResult` and `EvaluationDataset` classes in IPython environment ([c43de0a](https://github.com/googleapis/python-aiplatform/commit/c43de0ab174b22adf95f1b97c701ec2c4b4581fb))
+* Introduce RagFileMetadataConfig for importing metadata to Rag ([9b48d24](https://github.com/googleapis/python-aiplatform/commit/9b48d24ab90c57d4a49b3adf22a79cffbe065351))
+* RAG - Add Basic, Scaled and Unprovisioned tier in v1. ([726d3a2](https://github.com/googleapis/python-aiplatform/commit/726d3a26c81b95163416af5b3da2367086317573))
+* RAG - Add Scaled and Unprovisioned tier in preview. ([726d3a2](https://github.com/googleapis/python-aiplatform/commit/726d3a26c81b95163416af5b3da2367086317573))
+* RAG - Implement v1 `get_rag_engine_config` in `rag_data.py` ([726d3a2](https://github.com/googleapis/python-aiplatform/commit/726d3a26c81b95163416af5b3da2367086317573))
+* RAG - Implement v1 `update_rag_engine_config` in `rag_data.py` ([726d3a2](https://github.com/googleapis/python-aiplatform/commit/726d3a26c81b95163416af5b3da2367086317573))
+* Update v1 `create_corpus` to accept `encryption_spec` in `rag_data.py` ([865a68c](https://github.com/googleapis/python-aiplatform/commit/865a68c1273aa4e4e946a203bf226b80a723523f))
+
+
+### Bug Fixes
+
+* Update supported python version for create_reasoning_engine ([0059c01](https://github.com/googleapis/python-aiplatform/commit/0059c01b7395fc93be8d9214c938299678f67d3e))
+* Use none check to avoid 30s delay in agent run. ([84895b6](https://github.com/googleapis/python-aiplatform/commit/84895b6c6cd8d898d8472f0a1ace12a8b420717b))
+
+
+### Documentation
+
+* Add GenAI client examples to readme ([f1e17a6](https://github.com/googleapis/python-aiplatform/commit/f1e17a6b35fb31b7a5eb589a132df5df0ad7e3e4))
+
 ## [1.97.0](https://github.com/googleapis/python-aiplatform/compare/v1.96.0...v1.97.0) (2025-06-11)
diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/gapic_version.py
+++ b/google/cloud/aiplatform/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py
index 7e6265b87a..4ae710b3b4 100644
--- a/google/cloud/aiplatform/version.py
+++ b/google/cloud/aiplatform/version.py
@@ -15,4 +15,4 @@
 # limitations under the License.
 #

-__version__ = "1.97.0"
+__version__ = "1.98.0"
diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform_v1/gapic_version.py
+++ b/google/cloud/aiplatform_v1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py
index 4fb60fd4ef..1a1b7dfaef 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-__version__ = "1.97.0" # {x-release-please-version}
+__version__ = "1.98.0" # {x-release-please-version}
diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py
index 2d87273428..ebd0774020 100644
--- a/pypi/_vertex_ai_placeholder/version.py
+++ b/pypi/_vertex_ai_placeholder/version.py
@@ -15,4 +15,4 @@
 # limitations under the License.
 #

-__version__ = "1.97.0"
+__version__ = "1.98.0"
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
index f8471f4aef..524f1eb18e 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
@@ -8,7 +8,7 @@
     ],
     "language": "PYTHON",
     "name": "google-cloud-aiplatform",
-    "version": "0.1.0"
+    "version": "1.98.0"
   },
   "snippets": [
     {
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
index 418c8046bc..a8587832b3 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
@@ -8,7 +8,7 @@
     ],
     "language": "PYTHON",
     "name": "google-cloud-aiplatform",
-    "version": "0.1.0"
+    "version": "1.98.0"
   },
   "snippets": [
     {