@@ -743,6 +743,55 @@ def get_handler_for_metric(
  raise ValueError(f"Unsupported metric: {metric.name}")


+def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]:
+  """Calculates win/tie rates for comparison results."""
+  if not eval_result.eval_case_results:
+    return {}
+  max_models = max(
+      (
+          len(case.response_candidate_results)
+          for case in eval_result.eval_case_results
+          if case.response_candidate_results
+      ),
+      default=0,
+  )
+  if max_models == 0:
+    return {}
+  stats = collections.defaultdict(
+      lambda: {"wins": [0] * max_models, "ties": 0, "valid_comparisons": 0}
+  )
+  for case in eval_result.eval_case_results:
+    if not case.response_candidate_results:
+      continue
+    scores_by_metric = collections.defaultdict(list)
+    for idx, candidate in enumerate(case.response_candidate_results):
+      for name, res in (
+          candidate.metric_results.items() if candidate.metric_results else {}
+      ):
+        if res.score is not None:
+          scores_by_metric[name].append({"score": res.score, "cand_idx": idx})
+    for name, scores in scores_by_metric.items():
+      if not scores:
+        continue
+      stats[name]["valid_comparisons"] += 1
+      max_score = max(s["score"] for s in scores)
+      winners = [s["cand_idx"] for s in scores if s["score"] == max_score]
+      if len(winners) == 1:
+        stats[name]["wins"][winners[0]] += 1
+      else:
+        stats[name]["ties"] += 1
+  win_rates = {}
+  for name, metric_stats in stats.items():
+    if metric_stats["valid_comparisons"] > 0:
+      win_rates[name] = {
+          "win_rates": [
+              w / metric_stats["valid_comparisons"] for w in metric_stats["wins"]
+          ],
+          "tie_rate": metric_stats["ties"] / metric_stats["valid_comparisons"],
+      }
+  return win_rates
+
+
def _aggregate_metric_results(
    metric_handlers: list[MetricHandler],
    eval_case_results: list[types.EvalCaseResult],
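Reviewer note: the helper added in the hunk above returns one entry per metric, holding a per-candidate list of win rates plus a tie rate; every valid comparison counts as either a single win or a tie, so the values for a metric sum to 1. A hypothetical return value for two candidates and a single "quality" metric over ten valid comparisons (the metric name and numbers are illustrative, not from the change):

{
    "quality": {
        "win_rates": [0.6, 0.3],  # candidate 0 won 6 of 10 cases, candidate 1 won 3
        "tie_rate": 0.1,          # 1 of 10 cases tied on the top score
    }
}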
@@ -1001,18 +1050,27 @@ def compute_metrics_and_aggregate(
    )
    final_eval_case_results.append(eval_case_result)

-  aggregated_metric_results = _aggregate_metric_results(
-      metric_handlers, final_eval_case_results
-  )
-
  if submission_errors:
    logger.warning("Encountered %d submission errors.", len(submission_errors))
    logger.warning("Submission errors: %s", submission_errors)
  if execution_errors:
    logger.warning("Encountered %d execution errors.", len(execution_errors))
    logger.warning("Execution errors: %s", execution_errors)

-  return types.EvaluationResult(
+  aggregated_metric_results = _aggregate_metric_results(
+      metric_handlers, final_eval_case_results
+  )
+  eval_result = types.EvaluationResult(
      eval_case_results=final_eval_case_results,
      summary_metrics=aggregated_metric_results,
  )
+  if evaluation_run_config.num_response_candidates > 1:
+    try:
+      eval_result.win_rates = calculate_win_rates(eval_result)
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      logger.error(
+          "Error calculating win rates: %s",
+          e,
+          exc_info=True,
+      )
+  return eval_result
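
As a sanity check on the arithmetic, here is a small SDK-free sketch that mirrors the per-case win/tie bookkeeping in `calculate_win_rates`; the score lists and variable names are made up for illustration and are not part of the change:

# Toy recreation of the win/tie logic: 3 cases, 2 candidates, one metric.
scores_per_case = [
    [0.9, 0.4],  # candidate 0 has the top score -> win for candidate 0
    [0.7, 0.2],  # candidate 0 has the top score -> win for candidate 0
    [0.5, 0.5],  # shared top score -> tie
]

wins = [0, 0]
ties = 0
for case_scores in scores_per_case:
    top = max(case_scores)
    winners = [i for i, s in enumerate(case_scores) if s == top]
    if len(winners) == 1:
        wins[winners[0]] += 1
    else:
        ties += 1

valid_comparisons = len(scores_per_case)
print({
    "win_rates": [w / valid_comparisons for w in wins],  # [0.666..., 0.0]
    "tie_rate": ties / valid_comparisons,                # 0.333...
})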