Update Response Evaluators to use the new eval schema. · DarioMR1/adk-python@ada24d7 · GitHub


Commit ada24d7

ankursharmas authored and copybara-github committed
Update Response Evaluators to use the new eval schema.
PiperOrigin-RevId: 758929683
1 parent ee674ce commit ada24d7

File tree: 2 files changed (+149, -52 -54 lines)

src/google/adk/cli/cli_eval.py
src/google/adk/evaluation/response_evaluator.py

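Both files build on the evaluator schema under google/adk/evaluation/evaluator.py. For orientation, here is a minimal sketch of that interface as this commit uses it. It is inferred from the call sites in the diffs below, so the class shapes, field types, and defaults are assumptions rather than the actual definitions.

```python
# Sketch only: inferred from how the diffs use these types, not copied from
# google/adk/evaluation/evaluator.py. Types and defaults are assumptions.
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional


class EvalStatus(Enum):
  PASSED = 1
  FAILED = 2
  NOT_EVALUATED = 3


@dataclass
class PerInvocationResult:
  """Score and status for a single (actual, expected) invocation pair."""

  actual_invocation: Any
  expected_invocation: Any
  score: Optional[float] = None
  eval_status: EvalStatus = EvalStatus.NOT_EVALUATED


@dataclass
class EvaluationResult:
  """Aggregate result returned by Evaluator.evaluate_invocations()."""

  overall_score: Optional[float] = None
  overall_eval_status: EvalStatus = EvalStatus.NOT_EVALUATED
  per_invocation_results: list = field(default_factory=list)


class Evaluator:
  """Interface implemented by TrajectoryEvaluator and ResponseEvaluator."""

  def evaluate_invocations(
      self,
      actual_invocations: list,
      expected_invocations: list,
  ) -> EvaluationResult:
    raise NotImplementedError()
```

TrajectoryEvaluator already implements this interface in the pre-change code; the second file below brings ResponseEvaluator onto it as well.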

src/google/adk/cli/cli_eval.py

Lines changed: 43 additions & 52 deletions
```diff
@@ -32,6 +32,7 @@
 from ..evaluation.eval_case import EvalCase
 from ..evaluation.eval_case import Invocation
 from ..evaluation.evaluator import EvalStatus
+from ..evaluation.evaluator import Evaluator
 from ..sessions.base_session_service import BaseSessionService
 from ..sessions.session import Session
 from .utils import common
@@ -271,55 +272,32 @@ async def run_evals(
         overall_eval_metric_results = []
 
         for eval_metric in eval_metrics:
-          if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
-            evaluation_result = TrajectoryEvaluator(
-                eval_metric.threshold
-            ).evaluate_invocations(
-                actual_invocations=inference_result,
-                expected_invocations=eval_case.conversation,
-            )
-            overall_eval_metric_results.append(
+          metric_evaluator = _get_evaluator(eval_metric)
+
+          evaluation_result = metric_evaluator.evaluate_invocations(
+              actual_invocations=inference_result,
+              expected_invocations=eval_case.conversation,
+          )
+
+          overall_eval_metric_results.append(
+              EvalMetricResult(
+                  metric_name=eval_metric.metric_name,
+                  threshold=eval_metric.threshold,
+                  score=evaluation_result.overall_score,
+                  eval_status=evaluation_result.overall_eval_status,
+              )
+          )
+          for index, per_invocation_result in enumerate(
+              evaluation_result.per_invocation_results
+          ):
+            eval_metric_result_per_invocation[index].eval_metric_results.append(
                 EvalMetricResult(
                     metric_name=eval_metric.metric_name,
                     threshold=eval_metric.threshold,
-                    score=evaluation_result.overall_score,
-                    eval_status=evaluation_result.overall_eval_status,
+                    score=per_invocation_result.score,
+                    eval_status=per_invocation_result.eval_status,
                 )
             )
-            for index, per_invocation_result in enumerate(
-                evaluation_result.per_invocation_results
-            ):
-              eval_metric_result_per_invocation[
-                  index
-              ].eval_metric_results.append(
-                  EvalMetricResult(
-                      metric_name=eval_metric.metric_name,
-                      threshold=eval_metric.threshold,
-                      score=per_invocation_result.score,
-                      eval_status=per_invocation_result.eval_status,
-                  )
-              )
-
-          # elif eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_MATCH_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["rouge_1/mean"].item()
-          #   )
-          # elif eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY:
-          #   score = ResponseEvaluator.evaluate(
-          #       [inference_result],
-          #       [RESPONSE_EVALUATION_SCORE_KEY],
-          #       print_detailed_results=print_detailed_results,
-          #   )
-          #   eval_metric_result = _get_eval_metric_result(
-          #       eval_metric, score["coherence/mean"].item()
-          #   )
-          else:
-            logger.warning("`%s` is not supported.", eval_metric.metric_name)
 
         final_eval_status = EvalStatus.NOT_EVALUATED
         # Go over the all the eval statuses and mark the final eval status as
```
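For readers skimming the hunk above: each metric's EvaluationResult is folded into two structures, one aggregate EvalMetricResult appended to overall_eval_metric_results, and one EvalMetricResult appended, by matching index, to each invocation's row in eval_metric_result_per_invocation. A minimal sketch of that folding, with plain dicts standing in for the ADK types and purely illustrative values:

```python
# Sketch of the folding done by the loop above; plain dicts stand in for
# EvalMetricResult and the per-invocation rows, and the values are illustrative.
metric_name, threshold = "response_match_score", 0.8

evaluation_result = {  # what evaluate_invocations() returns
    "overall_score": 0.9,
    "overall_eval_status": "PASSED",
    "per_invocation_results": [
        {"score": 1.0, "eval_status": "PASSED"},
        {"score": 0.8, "eval_status": "PASSED"},
    ],
}

# One aggregate row per metric.
overall_eval_metric_results = [{
    "metric_name": metric_name,
    "threshold": threshold,
    "score": evaluation_result["overall_score"],
    "eval_status": evaluation_result["overall_eval_status"],
}]

# One row per invocation: result index i belongs to invocation i.
eval_metric_result_per_invocation = [
    {"eval_metric_results": []}
    for _ in evaluation_result["per_invocation_results"]
]
for index, per_invocation_result in enumerate(
    evaluation_result["per_invocation_results"]
):
  eval_metric_result_per_invocation[index]["eval_metric_results"].append({
      "metric_name": metric_name,
      "threshold": threshold,
      "score": per_invocation_result["score"],
      "eval_status": per_invocation_result["eval_status"],
  })
```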
```diff
@@ -356,13 +334,26 @@ async def run_evals(
 
         print(f"Result: {result}\n")
 
-    except Exception as e:
-      print(f"Error: {e}")
-      logger.info("Error: %s", str(traceback.format_exc()))
+    except Exception:
+      # Catching the general exception, so that we don't block other eval
+      # cases.
+      logger.exception(f"Eval failed for `{eval_set_id}:{eval_name}`")
 
 
-def _get_eval_metric_result(eval_metric, score):
-  eval_status = (
-      EvalStatus.PASSED if score >= eval_metric.threshold else EvalStatus.FAILED
-  )
-  return EvalMetricResult(score=score, eval_status=eval_status)
+def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
+  try:
+    from ..evaluation.response_evaluator import ResponseEvaluator
+    from ..evaluation.trajectory_evaluator import TrajectoryEvaluator
+  except ModuleNotFoundError as e:
+    raise ModuleNotFoundError(MISSING_EVAL_DEPENDENCIES_MESSAGE) from e
+  if eval_metric.metric_name == TOOL_TRAJECTORY_SCORE_KEY:
+    return TrajectoryEvaluator(threshold=eval_metric.threshold)
+  elif (
+      eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
+      or eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY
+  ):
+    return ResponseEvaluator(
+        threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
+    )
+
+  raise ValueError(f"Unsupported eval metric: {eval_metric}")
```
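With this change, the only metric-specific branching left in cli_eval.py lives inside _get_evaluator. A self-contained sketch of the dispatch pattern follows; the stub classes and the literal value assumed for TOOL_TRAJECTORY_SCORE_KEY are placeholders for illustration, not the ADK API.

```python
# Pattern sketch of _get_evaluator's dispatch; the stub evaluators and the
# TOOL_TRAJECTORY key value below are placeholders, not ADK's real classes.
TOOL_TRAJECTORY_SCORE_KEY = "tool_trajectory_avg_score"  # assumed constant value
RESPONSE_MATCH_SCORE_KEY = "response_match_score"        # matches response_evaluator.py
RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"


class StubTrajectoryEvaluator:
  def __init__(self, threshold: float):
    self.threshold = threshold


class StubResponseEvaluator:
  def __init__(self, threshold: float, metric_name: str):
    self.threshold = threshold
    self.metric_name = metric_name


def get_evaluator_sketch(metric_name: str, threshold: float):
  """Maps a metric name to the evaluator that scores it."""
  if metric_name == TOOL_TRAJECTORY_SCORE_KEY:
    return StubTrajectoryEvaluator(threshold=threshold)
  if metric_name in (RESPONSE_MATCH_SCORE_KEY, RESPONSE_EVALUATION_SCORE_KEY):
    # Both response metrics share one evaluator, parameterized by metric_name.
    return StubResponseEvaluator(threshold=threshold, metric_name=metric_name)
  raise ValueError(f"Unsupported eval metric: {metric_name}")


evaluator = get_evaluator_sketch(RESPONSE_MATCH_SCORE_KEY, threshold=0.8)
assert isinstance(evaluator, StubResponseEvaluator)
```

Deferring the TrajectoryEvaluator and ResponseEvaluator imports into _get_evaluator keeps the eval dependencies optional: if they are missing, the failure surfaces as MISSING_EVAL_DEPENDENCIES_MESSAGE rather than as an import error at module load.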

src/google/adk/evaluation/response_evaluator.py

Lines changed: 106 additions & 2 deletions
```diff
@@ -12,18 +12,122 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
+from typing import Any, Optional
 
+from deprecated import deprecated
+from google.genai import types as genai_types
 import pandas as pd
 from tabulate import tabulate
+from typing_extensions import override
 from vertexai.preview.evaluation import EvalTask
 from vertexai.preview.evaluation import MetricPromptTemplateExamples
 
+from .eval_case import IntermediateData
+from .eval_case import Invocation
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .evaluator import PerInvocationResult
 
-class ResponseEvaluator:
+
+class ResponseEvaluator(Evaluator):
   """Runs response evaluation for agents."""
 
+  def __init__(self, threshold: float, metric_name: str):
+    if "response_evaluation_score" == metric_name:
+      self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
+    elif "response_match_score" == metric_name:
+      self._metric_name = "rouge_1"
+    else:
+      raise ValueError(f"`{metric_name}` is not supported.")
+
+    self._threshold = threshold
+
+  @override
+  def evaluate_invocations(
+      self,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    total_score = 0.0
+    num_invocations = 0
+    per_invocation_results = []
+    for actual, expected in zip(actual_invocations, expected_invocations):
+      prompt = self._get_text(expected.user_content)
+      reference = self._get_text(expected.final_response)
+      response = self._get_text(actual.final_response)
+      actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
+      reference_trajectory = self._get_tool_use_trajectory(
+          expected.intermediate_data
+      )
+
+      eval_case = {
+          "prompt": prompt,
+          "reference": reference,
+          "response": response,
+          "actual_tool_use": actual_tool_use,
+          "reference_trajectory": reference_trajectory,
+      }
+
+      eval_case_result = ResponseEvaluator._perform_eval(
+          pd.DataFrame([eval_case]), [self._metric_name]
+      )
+      score = self._get_score(eval_case_result)
+      per_invocation_results.append(
+          PerInvocationResult(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              score=score,
+              eval_status=self._get_eval_status(score),
+          )
+      )
+      total_score += score
+      num_invocations += 1
+
+    if per_invocation_results:
+      overall_score = total_score / num_invocations
+      return EvaluationResult(
+          overall_score=overall_score,
+          overall_eval_status=self._get_eval_status(overall_score),
+          per_invocation_results=per_invocation_results,
+      )
+
+    return EvaluationResult()
+
+  def _get_text(self, content: Optional[genai_types.Content]) -> str:
+    if content and content.parts:
+      return "\n".join([p.text for p in content.parts if p.text])
+
+    return ""
+
+  def _get_tool_use_trajectory(
+      self, intermediate_data: Optional[IntermediateData]
+  ) -> list[dict[str, Any]]:
+    tool_use_trajectory = []
+    if not intermediate_data:
+      return tool_use_trajectory
+
+    for function_call in intermediate_data.tool_uses:
+      tool_use_trajectory.append({
+          "tool_name": function_call.name,
+          "tool_input": function_call.args or {},
+      })
+
+    return tool_use_trajectory
+
+  def _get_score(self, eval_result) -> float:
+    return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
+
+  def _get_eval_status(self, score: float):
+    return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
+
   @staticmethod
+  @deprecated(
+      reason=(
+          "This method has been deprecated and will be removed soon. Please use"
+          " evaluate_invocations instead."
+      )
+  )
   def evaluate(
       raw_eval_dataset: list[list[dict[str, Any]]],
       evaluation_criteria: list[str],
```