 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_case import Invocation
 from ..evaluation.evaluator import EvalStatus
+from ..evaluation.evaluator import Evaluator
 from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common
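The newly imported `Evaluator` is the abstraction the refactored loop below programs against. Here is a minimal sketch of the assumed contract, inferred only from how `evaluate_invocations` and its result are used in this diff; the real definitions live in `..evaluation.evaluator` and may differ in detail.

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


@dataclass
class PerInvocationResult:
  score: Optional[float] = None
  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED


@dataclass
class EvaluationResult:
  overall_score: Optional[float] = None
  overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
  per_invocation_results: List[PerInvocationResult] = field(default_factory=list)


class Evaluator(ABC):
  """One evaluator per metric; concrete metrics implement this interface."""

  @abstractmethod
  def evaluate_invocations(
      self, actual_invocations, expected_invocations
  ) -> EvaluationResult:
    ...

With this shape, `TrajectoryEvaluator` and `ResponseEvaluator` are interchangeable from the caller's point of view, which is what lets the per-metric `if/elif` below collapse into a single factory call.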
@@ -271,55 +272,32 @@ async def run_evals(
         overall_eval_metric_results = []
 
         for eval_metric in eval_metrics:
-          if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-            evaluation_result = TrajectoryEvaluator(
-                eval_metric.threshold
-            ).evaluate_invocations(
-                actual_invocations=inference_result,
-                expected_invocations=eval_case.conversation,
-            )
-            overall_eval_metric_results.append(
+          metric_evaluator = _get_evaluator(eval_metric)
+
+          evaluation_result = metric_evaluator.evaluate_invocations(
+              actual_invocations=inference_result,
+              expected_invocations=eval_case.conversation,
+          )
+
+          overall_eval_metric_results.append(
+              EvalMetricResult(
+                  metric_name=eval_metric.metric_name,
+                  threshold=eval_metric.threshold,
+                  score=evaluation_result.overall_score,
+                  eval_status=evaluation_result.overall_eval_status,
+              )
+          )
+          for index, per_invocation_result in enumerate(
+              evaluation_result.per_invocation_results
+          ):
+            eval_metric_result_per_invocation[index].eval_metric_results.append(
                 EvalMetricResult(
                     metric_name=eval_metric.metric_name,
                     threshold=eval_metric.threshold,
-                    score=evaluation_result.overall_score,
-                    eval_status=evaluation_result.overall_eval_status,
+                    score=per_invocation_result.score,
+                    eval_status=per_invocation_result.eval_status,
                 )
             )
-            for index, per_invocation_result in enumerate(
-                evaluation_result.per_invocation_results
-            ):
-              eval_metric_result_per_invocation[
-                  index
-              ].eval_metric_results.append(
-                  EvalMetricResult(
-                      metric_name=eval_metric.metric_name,
-                      threshold=eval_metric.threshold,
-                      score=per_invocation_result.score,
-                      eval_status=per_invocation_result.eval_status,
-                  )
-              )
-
-          # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_MATCH_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["rouge_1/mean"].item()
-          #   )
-          # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_EVALUATION_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["coherence/mean"].item()
-          #   )
-          else:
-            logger.warning("`%s` is not supported.", eval_metric.metric_name)
 
         final_eval_status = EvalStatus.NOT_EVALUATED
         # Go over the all the eval statuses and mark the final eval status as
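One implicit assumption in the indexed append above: `eval_metric_result_per_invocation` must already hold one bucket per expected invocation, in the same order that evaluators return `per_invocation_results`. A hypothetical sketch of that setup follows; the container and field names are stand-ins for illustration, not the actual ADK classes.

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class InvocationBucket:
  """Hypothetical stand-in for the per-invocation result container."""

  actual_invocation: Any
  expected_invocation: Any
  eval_metric_results: List[Any] = field(default_factory=list)


def build_buckets(inference_result, expected_conversation):
  # One bucket per (actual, expected) pair, so evaluators that report
  # per-invocation results in conversation order can be merged by index.
  return [
      InvocationBucket(actual_invocation=actual, expected_invocation=expected)
      for actual, expected in zip(inference_result, expected_conversation)
  ]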
@@ -356,13 +334,26 @@ async def run_evals(
 
         print(f"Result: {result}\n")
 
-      except Exception as e:
-        print(f"Error: {e}")
-        logger.info("Error: %s", str(traceback.format_exc()))
+      except Exception:
+        # Catching the general exception, so that we don't block other eval
+        # cases.
+        logger.exception(f"Eval failed for `{eval_set_id}:{eval_name}`")
 
 
-def _get_eval_metric_result(eval_metric, score):
-  eval_status = (
-      EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
-  )
-  return EvalMetricResult(score=score, eval_status=eval_status)
+def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
+  try:
+    from ..evaluation.response_evaluator import ResponseEvaluator
+    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
+  except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
+  if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
+    return TrajectoryEvaluator(threshold=eval_metric.threshold)
+  elif (
+      eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
+      or eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY
+  ):
+    return ResponseEvaluator(
+        threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
+    )
+
+  raise ValueError(f"Unsupported eval metric: {eval_metric}")
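A hedged usage sketch of the new `_get_evaluator` factory; it assumes `EvalMetric` can be constructed with the `metric_name` and `threshold` fields the function reads, and the threshold values shown are illustrative only.

# Dispatch by metric name; unknown metrics now fail loudly instead of being
# silently skipped with a warning as in the removed else branch.
tool_metric = EvalMetric(metric_name=TOOL_TRAJECTORY_SCORE_KEY, threshold=1.0)
_get_evaluator(tool_metric)   # -> TrajectoryEvaluator(threshold=1.0)

match_metric = EvalMetric(metric_name=RESPONSE_MATCH_SCORE_KEY, threshold=0.8)
_get_evaluator(match_metric)  # -> ResponseEvaluator(threshold=0.8,
                              #    metric_name=RESPONSE_MATCH_SCORE_KEY)

try:
  _get_evaluator(EvalMetric(metric_name="unknown_metric", threshold=0.5))
except ValueError as err:
  print(err)  # Unsupported eval metric: ...

Keeping the `ResponseEvaluator`/`TrajectoryEvaluator` imports inside the factory preserves the existing pattern of surfacing `MISSING_EVAL_DEPENDENCIES_MESSAGE` when the optional eval dependencies are not installed.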