|
13 | 13 | # limitations under the License.
|
14 | 14 | #
|
15 | 15 | # pylint: disable=protected-access,bad-continuation
|
16 |
| -import copy |
| 16 | + |
17 | 17 | import importlib
|
18 | 18 | from unittest import mock
|
19 | 19 |
|
20 |
| -from google.cloud import aiplatform |
21 | 20 | import vertexai
|
22 |
| -from google.cloud.aiplatform import initializer as aiplatform_initializer |
23 |
| -from google.cloud.aiplatform.compat.services import job_service_client |
24 |
| -from google.cloud.aiplatform.compat.types import ( |
25 |
| - custom_job as gca_custom_job_compat, |
26 |
| -) |
27 |
| -from google.cloud.aiplatform.compat.types import io as gca_io_compat |
28 |
| -from google.cloud.aiplatform.compat.types import ( |
29 |
| - job_state as gca_job_state_compat, |
30 |
| -) |
31 |
| - |
32 |
| -# from google.cloud.aiplatform.utils import gcs_utils |
33 |
| -# from google.genai import client |
| 21 | +from vertexai._genai import prompt_optimizer |
| 22 | +from vertexai._genai import types |
| 23 | +from google.genai import client |
34 | 24 | import pytest
|
35 | 25 |
|
36 | 26 |
|
|
42 | 32 | _TEST_DISPLAY_NAME = f"{_TEST_PARENT}/customJobs/12345"
|
43 | 33 | _TEST_BASE_OUTPUT_DIR = "gs://test_bucket/test_base_output_dir"
|
44 | 34 |
|
45 |
| -_TEST_CUSTOM_JOB_PROTO = gca_custom_job_compat.CustomJob( |
46 |
| - display_name=_TEST_DISPLAY_NAME, |
47 |
| - job_spec={ |
48 |
| - "base_output_directory": gca_io_compat.GcsDestination( |
49 |
| - output_uri_prefix=_TEST_BASE_OUTPUT_DIR |
50 |
| - ), |
51 |
| - }, |
52 |
| - labels={"trained_by_vertex_ai": "true"}, |
53 |
| -) |
54 |
| - |
55 |
| - |
56 |
| -@pytest.fixture |
57 |
| -def mock_create_custom_job(): |
58 |
| - with mock.patch.object( |
59 |
| - job_service_client.JobServiceClient, "create_custom_job" |
60 |
| - ) as create_custom_job_mock: |
61 |
| - custom_job_proto = copy.deepcopy(_TEST_CUSTOM_JOB_PROTO) |
62 |
| - custom_job_proto.name = _TEST_DISPLAY_NAME |
63 |
| - custom_job_proto.state = gca_job_state_compat.JobState.JOB_STATE_PENDING |
64 |
| - create_custom_job_mock.return_value = custom_job_proto |
65 |
| - yield create_custom_job_mock |
66 |
| - |
67 | 35 |
|
68 | 36 | class TestPromptOptimizer:
|
69 | 37 | """Unit tests for the Prompt Optimizer client."""
|
70 | 38 |
|
71 | 39 | def setup_method(self):
|
72 |
| - importlib.reload(aiplatform_initializer) |
73 |
| - importlib.reload(aiplatform) |
74 | 40 | importlib.reload(vertexai)
|
75 | 41 | vertexai.init(
|
76 | 42 | project=_TEST_PROJECT,
|
77 | 43 | location=_TEST_LOCATION,
|
78 | 44 | )
|
79 | 45 |
|
80 |
| - # @pytest.mark.usefixtures("google_auth_mock") |
81 |
| - # def test_prompt_optimizer_client(self): |
82 |
| - # test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) |
83 |
| - # assert test_client is not None |
84 |
| - # assert test_client._api_client.vertexai |
85 |
| - # assert test_client._api_client.project == _TEST_PROJECT |
86 |
| - # assert test_client._api_client.location == _TEST_LOCATION |
| 46 | + @pytest.mark.usefixtures("google_auth_mock") |
| 47 | + def test_prompt_optimizer_client(self): |
| 48 | + test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) |
| 49 | + assert test_client.prompt_optimizer is not None |
| 50 | + |
| 51 | + @mock.patch.object(client.Client, "_get_api_client") |
| 52 | + @mock.patch.object(prompt_optimizer.PromptOptimizer, "_create_custom_job_resource") |
| 53 | + def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client): |
| 54 | + """Test that prompt_optimizer.optimize method creates a custom job.""" |
| 55 | + test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) |
| 56 | +
8000
test_client.prompt_optimizer.optimize( |
| 57 | + method="vapo", |
| 58 | + config=types.PromptOptimizerVAPOConfig( |
| 59 | + config_path="gs://ssusie-vapo-sdk-test/config.json", |
| 60 | + wait_for_completion=False, |
| 61 | + service_account="test-service-account", |
| 62 | + ), |
| 63 | + ) |
| 64 | + mock_client.assert_called_once() |
| 65 | + mock_custom_job.assert_called_once() |
87 | 66 |
|
88 |
| - # @mock.patch.object(client.Client, "_get_api_client") |
89 |
| - # @mock.patch.object( |
90 |
| - # gcs_utils.resource_manager_utils, "get_project_number", return_value=12345 |
91 |
| - # ) |
92 |
| - # def test_prompt_optimizer_optimize( |
93 |
| - # self, mock_get_project_number, mock_client, mock_create_custom_job |
94 |
| - # ): |
95 |
| - # """Test that prompt_optimizer.optimize method creates a custom job.""" |
96 |
| - # test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) |
97 |
| - # test_client.prompt_optimizer.optimize( |
98 |
| - # method="vapo", |
99 |
| - # config={ |
100 |
| - # "config_path": "gs://ssusie-vapo-sdk-test/config.json", |
101 |
| - # "wait_for_completion": False, |
102 |
| - # }, |
103 |
| - # ) |
104 |
| - # mock_create_custom_job.assert_called_once() |
105 |
| - # mock_get_project_number.assert_called_once() |
| 67 | + # TODO(b/415060797): add more tests for prompt_optimizer.optimize |
0 commit comments