@@ -1556,6 +1556,84 @@ def assess_batch_prediction_resources(
1556
1556
audio_token_count = assessment_result .audio_token_count ,
1557
1557
)
1558
1558
1559
+ def assess_batch_prediction_validity (
1560
+ self ,
1561
+ * ,
1562
+ model_name : str ,
1563
+ template_config : Optional [GeminiTemplateConfig ] = None ,
1564
+ assess_request_timeout : Optional [float ] = None ,
1565
+ ) -> None :
1566
+ """Assess if the assembled dataset is valid in terms of batch prediction
1567
+ for a given model. Raises an error if the dataset is invalid, otherwise
1568
+ returns None.
1569
+
1570
+ Args:
1571
+ model_name (str):
1572
+ Required. The name of the model to assess the batch prediction
1573
+ validity for.
1574
+ dataset_usage (str):
1575
+ Required. The dataset usage to assess the batch prediction
1576
+ validity for.
1577
+ Must be one of the following: SFT_TRAINING, SFT_VALIDATION.
1578
+ template_config (GeminiTemplateConfig):
1579
+ Optional. The template config used to assemble the dataset
1580
+ before assessing the batch prediction validity. If not provided, the
1581
+ template config attached to the dataset will be used. Required
1582
+ if no template config is attached to the dataset.
1583
+ assess_request_timeout (float):
1584
+ Optional. The timeout for the assess batch prediction validity request.
1585
+ """
1586
+ request = self ._build_assess_data_request (template_config )
1587
+ request .batch_prediction_validation_assessment_config = gca_dataset_service .AssessDataRequest .BatchPredictionValidationAssessmentConfig (
1588
+ model_name = model_name ,
1589
+ )
1590
+ assess_lro = self .api_client .assess_data (
1591
+ request = request , timeout = assess_request_timeout
1592
+ )
1593
+ assess_lro .result (timeout = None )
1594
+
1595
+ def assess_batch_prediction_resources (
1596
+ self ,
1597
+ * ,
1598
+ model_name : str ,
1599
+ template_config : Optional [GeminiTemplateConfig ] = None ,
1600
+ assess_request_timeout : Optional [float ] = None ,
1601
+ ) -> BatchPredictionResourceUsageAssessmentResult :
1602
+ """Assess the batch prediction resources required for a given model.
1603
+
1604
+ Args:
1605
+ model_name (str):
1606
+ Required. The name of the model to assess the batch prediction resources
1607
+ for.
1608
+ template_config (GeminiTemplateConfig):
1609
+ Optional. The template config used to assemble the dataset
1610
+ before assessing the batch prediction resources. If not provided, the
1611
+ template config attached to the dataset will be used. Required
1612
+ if no template config is attached to the dataset.
1613
+ assess_request_timeout (float):
1614
+ Optional. The timeout for the assess batch prediction resources request.
1615
+ Returns:
1616
+ A dict containing the batch prediction resource usage assessment result. The
1617
+ dict contains the following keys:
1618
+ - token_count: The number of tokens in the dataset.
1619
+ - audio_token_count: The number of audio tokens in the dataset.
1620
+
1621
+ """
1622
+ request = self ._build_assess_data_request (template_config )
1623
+ request .batch_prediction_resource_usage_assessment_config = gca_dataset_service .AssessDataRequest .BatchPredictionResourceUsageAssessmentConfig (
1624
+ model_name = model_name
1625
+ )
1626
+
1627
+ assessment_result = (
1628
+ self .api_client .assess_data (request = request , timeout = assess_request_timeout )
1629
+ .result (timeout = None )
1630
+ .batch_prediction_resource_usage_assessment_result
1631
+ )
1632
+ return BatchPredictionResourceUsageAssessmentResult (
1633
+ token_count = assessment_result .token_count ,
1634
+ audio_token_count = assessment_result .audio_token_count ,
1635
+ )
1636
+
1559
1637
def _build_assess_data_request (
1560
1638
self ,
1561
1639
template_config : Optional [GeminiTemplateConfig ] = None ,
0 commit comments