feat: GenAI SDK client - add `show` method for `EvaluationResult` and `EvaluationDataset` classes in IPython environment · googleapis/python-aiplatform@c43de0a · GitHub

Commit c43de0a

jsondai authored and copybara-github committed
feat: GenAI SDK client - add show method for EvaluationResult and EvaluationDataset classes in IPython environment
PiperOrigin-RevId: 773042552
1 parent 726d3a2 commit c43de0a
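
As a point of reference, here is a minimal notebook sketch of what the new show() helpers enable. Only the show() methods on EvaluationDataset and EvaluationResult come from this commit; the client construction and the run_inference/evaluate calls below are assumptions about the surrounding GenAI SDK evals surface, not something this diff shows.

# Hypothetical IPython/Jupyter usage of the new show() methods.
# NOTE: vertexai.Client, client.evals.run_inference, and client.evals.evaluate
# are assumed names for the surrounding evals surface; only show() on
# EvaluationDataset and EvaluationResult is introduced by this commit.
import pandas as pd
import vertexai

client = vertexai.Client(project="my-project", location="us-central1")
prompts = pd.DataFrame({"prompt": ["Summarize: ...", "Translate to French: ..."]})

# Build an EvaluationDataset, then render it inline in the notebook cell.
eval_dataset = client.evals.run_inference(model="gemini-2.0-flash", src=prompts)
eval_dataset.show()

# Evaluate and render the EvaluationResult inline (summary metrics, plus win
# rates when multiple response candidates are compared).
eval_result = client.evals.evaluate(dataset=eval_dataset, metrics=[...])  # metric config elided
eval_result.show()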

File tree

5 files changed: +523 -15 lines changed

vertexai/_genai/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def __getattr__(name):
             _evals = importlib.import_module(".evals", __package__)
         except ImportError as e:
             raise ImportError(
-                "The 'evals' module requires 'pandas' and 'tqdm'. "
+                "The 'evals' module requires additional dependencies. "
                 "Please install them using pip install "
                 "google-cloud-aiplatform[evaluation]"
             ) from e
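
For orientation, this hunk sits inside a PEP 562 module-level __getattr__ that lazily imports the evals submodule and re-raises a missing optional dependency with an install hint. The sketch below illustrates that pattern; the exact control flow around the shown lines is an assumption, not taken from this diff.

# Sketch of the lazy-import pattern this hunk lives in (PEP 562 module
# __getattr__). The surrounding if/raise structure is illustrative only;
# the import call and error message match the diff above.
import importlib


def __getattr__(name):
    if name == "evals":
        try:
            return importlib.import_module(".evals", __package__)
        except ImportError as e:
            raise ImportError(
                "The 'evals' module requires additional dependencies. "
                "Please install them using pip install "
                "google-cloud-aiplatform[evaluation]"
            ) from e
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")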

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 63 additions & 5 deletions
@@ -743,6 +743,55 @@ def get_handler_for_metric(
     raise ValueError(f"Unsupported metric: {metric.name}")


+def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]:
+    """Calculates win/tie rates for comparison results."""
+    if not eval_result.eval_case_results:
+        return {}
+    max_models = max(
+        (
+            len(case.response_candidate_results)
+            for case in eval_result.eval_case_results
+            if case.response_candidate_results
+        ),
+        default=0,
+    )
+    if max_models == 0:
+        return {}
+    stats = collections.defaultdict(
+        lambda: {"wins": [0] * max_models, "ties": 0, "valid_comparisons": 0}
+    )
+    for case in eval_result.eval_case_results:
+        if not case.response_candidate_results:
+            continue
+        scores_by_metric = collections.defaultdict(list)
+        for idx, candidate in enumerate(case.response_candidate_results):
+            for name, res in (
+                candidate.metric_results.items() if candidate.metric_results else {}
+            ):
+                if res.score is not None:
+                    scores_by_metric[name].append({"score": res.score, "cand_idx": idx})
+        for name, scores in scores_by_metric.items():
+            if not scores:
+                continue
+            stats[name]["valid_comparisons"] += 1
+            max_score = max(s["score"] for s in scores)
+            winners = [s["cand_idx"] for s in scores if s["score"] == max_score]
+            if len(winners) == 1:
+                stats[name]["wins"][winners[0]] += 1
+            else:
+                stats[name]["ties"] += 1
+    win_rates = {}
+    for name, metric_stats in stats.items():
+        if metric_stats["valid_comparisons"] > 0:
+            win_rates[name] = {
+                "win_rates": [
+                    w / metric_stats["valid_comparisons"] for w in metric_stats["wins"]
+                ],
+                "tie_rate": metric_stats["ties"] / metric_stats["valid_comparisons"],
+            }
+    return win_rates
+
+
 def _aggregate_metric_results(
     metric_handlers: list[MetricHandler],
     eval_case_results: list[types.EvalCaseResult],
@@ -1001,18 +1050,27 @@ def compute_metrics_and_aggregate(
             )
             final_eval_case_results.append(eval_case_result)

-    aggregated_metric_results = _aggregate_metric_results(
-        metric_handlers, final_eval_case_results
-    )
-
     if submission_errors:
         logger.warning("Encountered %d submission errors.", len(submission_errors))
         logger.warning("Submission errors: %s", submission_errors)
     if execution_errors:
         logger.warning("Encountered %d execution errors.", len(execution_errors))
         logger.warning("Execution errors: %s", execution_errors)

-    return types.EvaluationResult(
+    aggregated_metric_results = _aggregate_metric_results(
+        metric_handlers, final_eval_case_results
+    )
+    eval_result = types.EvaluationResult(
         eval_case_results=final_eval_case_results,
         summary_metrics=aggregated_metric_results,
     )
+    if evaluation_run_config.num_response_candidates > 1:
+        try:
+            eval_result.win_rates = calculate_win_rates(eval_result)
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            logger.error(
+                "Error calculating win rates: %s",
+                e,
+                exc_info=True,
+            )
+    return eval_result
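
To make the behavior of the new calculate_win_rates helper concrete, here is a self-contained sketch of the same win/tie bookkeeping using plain dicts in place of types.EvaluationResult and its EvalCaseResult entries (those constructors are not shown in this diff): per metric, each eval case counts as one comparison, the single highest-scoring candidate takes a win, and equal top scores count as a tie.

# Standalone illustration of the win/tie-rate logic added above, using plain
# dicts instead of the SDK types. Field names here are for illustration only.
import collections

cases = [
    # one eval case, two response candidates, scored on one metric
    {"candidate_scores": [{"quality": 0.9}, {"quality": 0.7}]},
    {"candidate_scores": [{"quality": 0.8}, {"quality": 0.8}]},  # tie
]

num_candidates = max(len(c["candidate_scores"]) for c in cases)
stats = collections.defaultdict(
    lambda: {"wins": [0] * num_candidates, "ties": 0, "valid_comparisons": 0}
)

for case in cases:
    scores_by_metric = collections.defaultdict(list)
    for idx, candidate in enumerate(case["candidate_scores"]):
        for name, score in candidate.items():
            scores_by_metric[name].append((score, idx))
    for name, scores in scores_by_metric.items():
        stats[name]["valid_comparisons"] += 1
        top = max(score for score, _ in scores)
        winners = [idx for score, idx in scores if score == top]
        if len(winners) == 1:
            stats[name]["wins"][winners[0]] += 1
        else:
            stats[name]["ties"] += 1

for name, s in stats.items():
    n = s["valid_comparisons"]
    print(name, {"win_rates": [w / n for w in s["wins"]], "tie_rate": s["ties"] / n})
# quality {'win_rates': [0.5, 0.0], 'tie_rate': 0.5}

Dividing by valid_comparisons rather than the total number of cases mirrors the helper above, which only counts cases that actually produced scores for a given metric.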
