8000 feat: GenAI SDK client - Add Vertex AI Prompt Optimizer to the Gen AI… · googleapis/python-aiplatform@a83c7b2 · GitHub
[go: up one dir, main page]

Skip to content

Commit a83c7b2

Browse files
sararobcopybara-github
authored andcommitted
feat: GenAI SDK client - Add Vertex AI Prompt Optimizer to the Gen AI SDK (experimental)
PiperOrigin-RevId: 773036828
1 parent 0bb0588 commit a83c7b2

File tree

6 files changed

+1943
-304
lines changed

6 files changed

+1943
-304
lines changed

tests/unit/vertexai/genai/test_prompt_optimizer.py

Lines changed: 25 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,14 @@
1313
# limitations under the License.
1414
#
1515
# pylint: disable=protected-access,bad-continuation
16-
import copy
16+
1717
import importlib
1818
from unittest import mock
1919

20-
from google.cloud import aiplatform
2120
import vertexai
22-
from google.cloud.aiplatform import initializer as aiplatform_initializer
23-
from google.cloud.aiplatform.compat.services import job_service_client
24-
from google.cloud.aiplatform.compat.types import (
25-
custom_job as gca_custom_job_compat,
26-
)
27-
from google.cloud.aiplatform.compat.types import io as gca_io_compat
28-
from google.cloud.aiplatform.compat.types import (
29-
job_state as gca_job_state_compat,
30-
)
31-
32-
# from google.cloud.aiplatform.utils import gcs_utils
33-
# from google.genai import client
21+
from vertexai._genai import prompt_optimizer
22+
from vertexai._genai import types
23+
from google.genai import client
3424
import pytest
3525

3626

@@ -42,64 +32,36 @@
4232
_TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345"
4333
_TEST_BASE_OUTPUT_DIR = "gs://test_bucket/test_base_output_dir"
4434

45-
_TEST_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob(
46-
display_name=_TEST_DISPLAY_NAME,
47-
job_spec={
48-
"base_output_directory": gca_io_compat.GcsDestination(
49-
output_uri_prefix=_TEST_BASE_OUTPUT_DIR
50-
),
51-
},
52-
labels={"trained_by_vertex_ai": "true"},
53-
)
54-
55-
56-
@pytest.fixture
57-
def mock_create_custom_job():
58-
with mock.patch.object(
59-
job_service_client.JobServiceClient, "create_custom_job"
60-
) as create_custom_job_mock:
61-
custom_job_proto = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO)
62-
custom_job_proto.name = _TEST_DISPLAY_NAME
63-
custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_PENDING
64-
create_custom_job_mock.return_value = custom_job_proto
65-
yield create_custom_job_mock
66-
6735

6836
class TestPromptOptimizer:
6937
"""Unit tests for the Prompt Optimizer client."""
7038

7139
def setup_method(self):
72-
importlib.reload(aiplatform_initializer)
73-
importlib.reload(aiplatform)
7440
importlib.reload(vertexai)
7541
vertexai.init(
7642
project=_TEST_PROJECT,
7743
location=_TEST_LOCATION,
7844
)
7945

80-
# @pytest.mark.usefixtures("google_auth_mock")
81-
# def test_prompt_optimizer_client(self):
82-
# test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
83-
# assert test_client is not None
84-
# assert test_client._api_client.vertexai
85-
# assert test_client._api_client.project == _TEST_PROJECT
86-
# assert test_client._api_client.location == _TEST_LOCATION
46+
@pytest.mark.usefixtures("google_auth_mock")
47+
def test_prompt_optimizer_client(self):
48+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
49+
assert test_client.prompt_optimizer is not None
50+
51+
@mock.patch.object(client.Client, "_get_api_client")
52+
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_create_custom_job_resource")
53+
def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
54+
"""Test that prompt_optimizer.optimize method creates a custom job."""
55+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
56+
8000 test_client.prompt_optimizer.optimize(
57+
method="vapo",
58+
config=types.PromptOptimizerVAPOConfig(
59+
config_path="gs://ssusie-vapo-sdk-test/config.json",
60+
wait_for_completion=False,
61+
service_account="test-service-account",
62+
),
63+
)
64+
mock_client.assert_called_once()
65+
mock_custom_job.assert_called_once()
8766

88-
# @mock.patch.object(client.Client, "_get_api_client")
89-
# @mock.patch.object(
90-
# gcs_utils.resource_manager_utils, "get_project_number", return_value=12345
91-
# )
92-
# def test_prompt_optimizer_optimize(
93-
# self, mock_get_project_number, mock_client, mock_create_custom_job
94-
# ):
95-
# """Test that prompt_optimizer.optimize method creates a custom job."""
96-
# test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
97-
# test_client.prompt_optimizer.optimize(
98-
# method="vapo",
99-
# config={
100-
# "config_path": "gs://ssusie-vapo-sdk-test/config.json",
101-
# "wait_for_completion": False,
102-
# },
103-
# )
104-
# mock_create_custom_job.assert_called_once()
105-
# mock_get_project_number.assert_called_once()
67+
# TODO(b/415060797): add more tests for prompt_optimizer.optimize

0 commit comments

Comments
 (0)
0