8000 feat: RAG - Add Scaled and Unprovisioned tier in preview. · googleapis/python-aiplatform@4e708e5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4e708e5

Browse files
darshanmehta17copybara-github
authored andcommitted
feat: RAG - Add Scaled and Unprovisioned tier in preview.
feat: RAG - Implement v1 `update_rag_engine_config` in `rag_data.py` feat: RAG - Implement v1 `get_rag_engine_config` in `rag_data.py` feat: RAG - Add Basic, Scaled and Unprovisioned tier in v1. PiperOrigin-RevId: 772549031
1 parent 4df909c commit 4e708e5

File tree

11 files changed

+647
-22
lines changed

11 files changed

+647
-22
lines changed

tests/unit/vertex_rag/test_rag_constants.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,29 @@
1919
from google.cloud import aiplatform
2020

2121
from vertexai.rag import (
22+
Basic,
2223
Filter,
2324
LayoutParserConfig,
2425
LlmParserConfig,
2526
LlmRanker,
2627
Pinecone,
2728
RagCorpus,
29+
RagEngineConfig,
2830
RagFile,
31+
RagManagedDbConfig,
2932
RagResource,
3033
RagRetrievalConfig,
3134
RagVectorDbConfig,
3235
Ranking,
3336
RankService,
37+
Scaled,
3438
SharePointSource,
3539
SharePointSources,
3640
SlackChannelsSource,
3741
SlackChannel,
3842
JiraSource,
3943
JiraQuery,
44+
Unprovisioned,
4045
VertexVectorSearch,
4146
RagEmbeddingModelConfig,
4247
VertexAiSearchConfig,
@@ -45,9 +50,11 @@
4550

4651
from google.cloud.aiplatform_v1 import (
4752
GoogleDriveSource,
53+
RagEngineConfig as GapicRagEngineConfig,
4854
RagFileChunkingConfig,
4955
RagFileParsingConfig,
5056
RagFileTransformationConfig,
57+
RagManagedDbConfig as GapicRagManagedDbConfig,
5158
ImportRagFilesConfig,
5259
ImportRagFilesRequest,
5360
ImportRagFilesResponse,
@@ -677,6 +684,45 @@
677684
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_LLM_PARSER,
678685
)
679686

687+
# RagEngineConfig Resource
688+
TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = (
689+
f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragEngineConfig"
690+
)
691+
TEST_RAG_ENGINE_CONFIG_BASIC = RagEngineConfig(
692+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
693+
rag_managed_db_config=RagManagedDbConfig(tier=Basic()),
694+
)
695+
TEST_RAG_ENGINE_CONFIG_SCALED = RagEngineConfig(
696+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
697+
rag_managed_db_config=RagManagedDbConfig(tier=Scaled()),
698+
)
699+
TEST_RAG_ENGINE_CONFIG_UNPROVISIONED = RagEngineConfig(
700+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
701+
rag_managed_db_config=RagManagedDbConfig(tier=Unprovisioned()),
702+
)
703+
TEST_DEFAULT_RAG_ENGINE_CONFIG = RagEngineConfig(
704+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
705+
rag_managed_db_config=None,
706+
)
707+
TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC = GapicRagEngineConfig(
708+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
709+
rag_managed_db_config=GapicRagManagedDbConfig(
710+
basic=GapicRagManagedDbConfig.Basic()
711+
),
712+
)
713+
TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED = GapicRagEngineConfig(
714+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
715+
rag_managed_db_config=GapicRagManagedDbConfig(
716+
scaled=GapicRagManagedDbConfig.Scaled()
717+
),
718+
)
719+
TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED = GapicRagEngineConfig(
720+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
721+
rag_managed_db_config=GapicRagManagedDbConfig(
722+
unprovisioned=GapicRagManagedDbConfig.Unprovisioned()
723+
),
724+
)
725+
680726
# Inline Citations test constants
681727
TEST_ORIGINAL_TEXT = (
682728
"You can activate the parking radar using a switch or through the"

tests/unit/vertex_rag/test_rag_constants_preview.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,12 @@
6767
RagVectorDbConfig,
6868
RankService,
6969
Ranking,
70+
Scaled,
7071
SharePointSource,
7172
SharePointSources,
7273
SlackChannel,
7374
SlackChannelsSource,
75+
Unprovisioned,
7476
VertexAiSearchConfig,
7577
VertexFeatureStore,
7678
VertexPredictionEndpoint,
@@ -561,6 +563,14 @@
561563
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
562564
rag_managed_db_config=RagManagedDbConfig(tier=Basic()),
563565
)
566+
TEST_RAG_ENGINE_CONFIG_SCALED = RagEngineConfig(
567+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
568+
rag_managed_db_config=RagManagedDbConfig(tier=Scaled()),
569+
)
570+
TEST_RAG_ENGINE_CONFIG_UNPROVISIONED = RagEngineConfig(
571+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
572+
rag_managed_db_config=RagManagedDbConfig(tier=Unprovisioned()),
573+
)
564574
TEST_RAG_ENGINE_CONFIG_ENTERPRISE = RagEngineConfig(
565575
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
566576
rag_managed_db_config=RagManagedDbConfig(tier=Enterprise()),
@@ -575,6 +585,18 @@
575585
basic=GapicRagManagedDbConfig.Basic()
576586
),
577587
)
588+
TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED = GapicRagEngineConfig(
589+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
590+
rag_managed_db_config=GapicRagManagedDbConfig(
591+
scaled=GapicRagManagedDbConfig.Scaled()
592+
),
593+
)
594+
TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED = GapicRagEngineConfig(
595+
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
596+
rag_managed_db_config=GapicRagManagedDbConfig(
597+
unprovisioned=GapicRagManagedDbConfig.Unprovisioned()
598+
),
599+
)
578600
TEST_GAPIC_RAG_ENGINE_CONFIG_ENTERPRISE = GapicRagEngineConfig(
579601
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
580602
rag_managed_db_config=GapicRagManagedDbConfig(

tests/unit/vertex_rag/test_rag_data.py

Lines changed: 206 additions & 0 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,113 @@ def list_rag_corpora_pager_mock():
206206
yield list_rag_corpora_pager_mock
207207

208208

209+
@pytest.fixture()
210+
def update_rag_engine_config_basic_mock():
211+
with mock.patch.object(
212+
VertexRagDataServiceClient,
213+
"update_rag_engine_config",
214+
) as update_rag_engine_config_basic_mock:
215+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
216+
update_rag_engine_config_lro_mock.done.return_value = True
217+
update_rag_engine_config_lro_mock.result.return_value = (
218+
test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC
219+
)
220+
update_rag_engine_config_basic_mock.return_value = (
221+
update_rag_engine_config_lro_mock
222+
)
223+
yield update_rag_engine_config_basic_mock
224+
225+
226+
@pytest.fixture()
227+
def update_rag_engine_config_scaled_mock():
228+
with mock.patch.object(
229+
VertexRagDataServiceClient,
230+
"update_rag_engine_config",
231+
) as update_rag_engine_config_scaled_mock:
232+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
233+
update_rag_engine_config_lro_mock.done.return_value = True
234+
update_rag_engine_config_lro_mock.result.return_value = (
235+
test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED
236+
)
237+
update_rag_engine_config_scaled_mock.return_value = (
238+
update_rag_engine_config_lro_mock
239+
)
240+
yield update_rag_engine_config_scaled_mock
241+
242+
243+
@pytest.fixture()
244+
def update_rag_engine_config_unprovisioned_mock():
245+
with mock.patch.object(
246+
VertexRagDataServiceClient,
247+
"update_rag_engine_config",
248+
) as update_rag_engine_config_unprovisioned_mock:
249+
update_rag_engine_config_lro_mock = mock.Mock(ga_operation.Operation)
250+
update_rag_engine_config_lro_mock.done.return_value = True
251+
update_rag_engine_config_lro_mock.result.return_value = (
252+
test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED
253+
)
254+
update_rag_engine_config_unprovisioned_mock.return_value = (
255+
update_rag_engine_config_lro_mock
256+
)
257+
yield update_rag_engine_config_unprovisioned_mock
258+
259+
260+
@pytest.fixture()
261+
def update_rag_engine_config_mock_exception():
262+
with mock.patch.object(
263+
VertexRagDataServiceClient,
264+
"update_rag_engine_config",
265+
) as update_rag_engine_config_mock_exception:
266+
update_rag_engine_config_mock_exception.side_effect = Exception
267+
yield update_rag_engine_config_mock_exception
268+
269+
270+
@pytest.fixture()
271+
def get_rag_engine_basic_config_mock():
272+
with mock.patch.object(
273+
VertexRagDataServiceClient,
274+
"get_rag_engine_config",
275+
) as get_rag_engine_basic_config_mock:
276+
get_rag_engine_basic_config_mock.return_value = (
277+
test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_BASIC
278+
)
279+
yield get_rag_engine_basic_config_mock
280+
281+
282+
@pytest.fixture()
283+
def get_rag_engine_scaled_config_mock():
284+
with mock.patch.object(
285+
VertexRagDataServiceClient,
286+
"get_rag_engine_config",
287+
) as get_rag_engine_scaled_config_mock:
288+
get_rag_engine_scaled_config_mock.return_value = (
289+
test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_SCALED
290+
)
291+
yield get_rag_engine_scaled_config_mock
292+
293+
294+
@pytest.fixture()
295+
def get_rag_engine_unprovisioned_config_mock():
296+
with mock.patch.object(
297+
VertexRagDataServiceClient,
298+
"get_rag_engine_config",
299+
) as get_rag_engine_unprovisioned_config_mock:
300+
get_rag_engine_unprovisioned_config_mock.return_value = (
301+
test_rag_constants.TEST_GAPIC_RAG_ENGINE_CONFIG_UNPROVISIONED
302+
)
303+
yield get_rag_engine_unprovisioned_config_mock
304+
305+
306+
@pytest.fixture()
307+
def get_rag_engine_config_mock_exception():
308+
with mock.patch.object(
309+
VertexRagDataServiceClient,
310+
"get_rag_engine_config",
311+
) as get_rag_engine_config_mock_exception:
312+
get_rag_engine_config_mock_exception.side_effect = Exception
313+
yield get_rag_engine_config_mock_exception
314+
315+
209316
class MockResponse:
210317
def __init__(self, json_data, status_code):
211318
self.json_data = json_data
@@ -355,6 +462,13 @@ def import_files_request_eq(returned_request, expected_request):
355462
)
356463

357464

465+
def rag_engine_config_eq(returned_config, expected_config):
466+
assert returned_config.name == expected_config.name
467+
assert returned_config.rag_managed_db_config.__eq__(
468+
expected_config.rag_managed_db_config
469+
)
470+
471+
358472
@pytest.mark.usefixtures("google_auth_mock")
359473
class TestRagDataManagement:
360474
def setup_method(self):
@@ -1084,3 +1198,95 @@ def test_set_embedding_model_config_wrong_endpoint_format_error(self):
10841198
test_rag_constants.TEST_GAPIC_RAG_CORPUS,
10851199
)
10861200
e.match("endpoint must be of the format ")
1201+
1202+
def test_update_rag_engine_config_success(
1203+
self, update_rag_engine_config_basic_mock
1204+
):
1205+
rag_config = rag.update_rag_engine_config(
1206+
rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC,
1207+
)
1208+
assert update_rag_engine_config_basic_mock.call_count == 1
1209+
rag_engine_config_eq(
1210+
rag_config,
1211+
test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC,
1212+
)
1213+
1214+
def test_update_rag_engine_config_scaled_success(
1215+
self, update_rag_engine_config_scaled_mock
1216+
):
1217+
rag_config = rag.update_rag_engine_config(
1218+
rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED,
1219+
)
1220+
assert update_rag_engine_config_scaled_mock.call_count == 1
1221+
rag_engine_config_eq(
1222+
rag_config,
1223+
test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED,
1224+
)
1225+
1226+
def test_update_rag_engine_config_unprovisioned_success(
1227+
self, update_rag_engine_config_unprovisioned_mock
1228+
):
1229+
rag_config = rag.update_rag_engine_config(
1230+
rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED,
1231+
)
1232+
assert update_rag_engine_config_unprovisioned_mock.call_count == 1
1233+
rag_engine_config_eq(
1234+
rag_config,
1235+
test_rag_constants.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED,
1236+
)
1237+
1238+
@pytest.mark.usefixtures("update_rag_engine_config_mock_exception")
1239+
def test_update_rag_engine_config_failure(self):
1240+
with pytest.raises(RuntimeError) as e:
1241+
rag.update_rag_engine_config(
1242+
rag_engine_config=test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED,
1243+
)
1244+
e.match("Failed in RagEngineConfig update due to")
1245+
1246+
@pytest.mark.usefixtures("update_rag_engine_config_basic_mock")
1247+
def test_update_rag_engine_config_bad_input(
1248+
self, update_rag_engine_config_basic_mock
1249+
):
1250+
rag_config = rag.update_rag_engine_config(
1251+
rag_engine_config=test_rag_constants.TEST_DEFAULT_RAG_ENGINE_CONFIG,
1252+
)
1253+
assert update_rag_engine_config_basic_mock.call_count == 1
1254+
rag_engine_config_eq(
1255+
rag_config,
1256+
test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC,
1257+
)
1258+
1259+
@pytest.mark.usefixtures("get_rag_engine_basic_config_mock")
1260+
def test_get_rag_engine_config_success(self):
1261+
rag_config = rag.get_rag_engine_config(
1262+
name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
1263+
)
1264+
rag_engine_config_eq(
1265+
rag_config, test_rag_constants.TEST_RAG_ENGINE_CONFIG_BASIC
1266+
)
1267+
1268+
@pytest.mark.usefixtures("get_rag_engine_scaled_config_mock")
1269+
def test_get_rag_engine_config_scaled_success(self):
1270+
rag_config = rag.get_rag_engine_config(
1271+
name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
1272+
)
1273+
rag_engine_config_eq(
1274+
rag_config, test_rag_constants.TEST_RAG_ENGINE_CONFIG_SCALED
1275+
)
1276+
1277+
@pytest.mark.usefixtures("get_rag_engine_unprovisioned_config_mock")
1278+
def test_get_rag_engine_config_unprovisioned_success(self):
1279+
rag_config = rag.get_rag_engine_config(
1280+
name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
1281+
)
1282+
rag_engine_config_eq(
1283+
rag_config, test_rag_constants.TEST_RAG_ENGINE_CONFIG_UNPROVISIONED
1284+
)
1285+
1286+
@pytest.mark.usefixtures("get_rag_engine_config_mock_exception")
1287+
def test_get_rag_engine_config_failure(self):
1288+
with pytest.raises(RuntimeError) as e:
1289+
rag.get_rag_engine_config(
1290+
name=test_rag_constants.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
1291+
)
1292+
e.match("Failed in getting the RagEngineConfig due to")

0 commit comments

Comments
 (0)
0