Update AgentEvaluator to new new EvalSchema · ditinagrawal/adk-python@4c6820e · GitHub

Commit 4c6820e

ankursharmas authored and copybara-github committed
Update AgentEvaluator to new new EvalSchema
PiperOrigin-RevId: 759293759
1 parent bdd678d commit 4c6820e

9 files changed: +152 −148 lines changed

src/google/adk/cli/cli_eval.py

Lines changed: 1 addition & 3 deletions
@@ -17,10 +17,8 @@
 import logging
 import os
 import sys
-import traceback
 from typing import Any
 from typing import AsyncGenerator
-from typing import cast
 from typing import Optional
 import uuid

@@ -350,7 +348,7 @@ def _get_evaluator(eval_metric: EvalMetric) -> Evaluator:
     return TrajectoryEvaluator(threshold=eval_metric.threshold)
   elif (
       eval_metric.metric_name == RESPONSE_MATCH_SCORE_KEY
-      or eval_metric == RESPONSE_EVALUATION_SCORE_KEY
+      or eval_metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY
   ):
     return ResponseEvaluator(
         threshold=eval_metric.threshold, metric_name=eval_metric.metric_name
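The one-line fix above is more than cosmetic: the old branch compared the whole EvalMetric object against a string key, which can never be equal, so the response-evaluation metric fell through to the wrong path. A minimal sketch, using a hypothetical stand-in for EvalMetric (its real definition is not part of this diff) and an assumed key string, shows the difference:

# Hypothetical stand-in for EvalMetric, for illustration only.
from dataclasses import dataclass

RESPONSE_EVALUATION_SCORE_KEY = "response_evaluation_score"  # assumed value

@dataclass
class EvalMetricStub:
  metric_name: str
  threshold: float

metric = EvalMetricStub(metric_name=RESPONSE_EVALUATION_SCORE_KEY, threshold=0.5)
print(metric == RESPONSE_EVALUATION_SCORE_KEY)              # False: object vs. string
print(metric.metric_name == RESPONSE_EVALUATION_SCORE_KEY)  # True: the corrected check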

src/google/adk/evaluation/agent_evaluator.py

Lines changed: 88 additions & 118 deletions
@@ -18,8 +18,13 @@
 from typing import Dict
 from typing import List
 from typing import Union
-
+import uuid
+from .eval_set import EvalSet
 from .evaluation_generator import EvaluationGenerator
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .evaluator import Evaluator
+from .local_eval_sets_manager import convert_eval_set_to_pydanctic_schema
 from .response_evaluator import ResponseEvaluator
 from .trajectory_evaluator import TrajectoryEvaluator

@@ -75,6 +80,62 @@ def find_config_for_test_file(test_file: str):
       )
     return DEFAULT_CRITERIA

+  @staticmethod
+  async def evaluate_eval_set(
+      agent_module: str,
+      eval_set: EvalSet,
+      criteria: dict[str, float],
+      num_runs=NUM_RUNS,
+      agent_name=None,
+  ):
+    """Evaluates an agent using the given EvalSet.
+
+    Args:
+      agent_module: The path to python module that contains the definition of
+        the agent. There is convention in place here, where the code is going to
+        look for 'root_agent' in the loaded module.
+      eval_set: The eval set.
+      criteria: Evauation criterias, a dictionary of metric names to their
+        respective thresholds.
+      num_runs: Number of times all entries in the eval dataset should be
+        assessed.
+      agent_name: The name of the agent.
+    """
+    eval_case_responses_list = await EvaluationGenerator.generate_responses(
+        eval_set=eval_set,
+        agent_module_path=agent_module,
+        repeat_num=num_runs,
+        agent_name=agent_name,
+    )
+
+    for eval_case_responses in eval_case_responses_list:
+      actual_invocations = [
+          invocation
+          for invocations in eval_case_responses.responses
+          for invocation in invocations
+      ]
+      expected_invocations = (
+          eval_case_responses.eval_case.conversation * num_runs
+      )
+
+      for metric_name, threshold in criteria.items():
+        metric_evaluator = AgentEvaluator._get_metric_evaluator(
+            metric_name=metric_name, threshold=threshold
+        )
+
+        evaluation_result: EvaluationResult = (
+            metric_evaluator.evaluate_invocations(
+                actual_invocations=actual_invocations,
+                expected_invocations=expected_invocations,
+            )
+        )
+
+        assert evaluation_result.overall_eval_status == EvalStatus.PASSED, (
+            f"`{eval_case_responses.eval_case.eval_id}`: "
+            f"{metric_name} for {agent_module} Failed. Expected {threshold},"
+            f" but got {evaluation_result.overall_score}."
+        )
+
   @staticmethod
   async def evaluate(
       agent_module,
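The new evaluate_eval_set entry point generates inferences once per EvalSet and then checks every criterion through the shared metric dispatch. A hedged usage sketch (assuming pytest with pytest-asyncio; the module path, metric key, and threshold are illustrative, and my_eval_set is assumed to be an EvalSet built elsewhere, for example via convert_eval_set_to_pydanctic_schema):

import pytest

from google.adk.evaluation.agent_evaluator import AgentEvaluator

@pytest.mark.asyncio
async def test_agent_against_eval_set(my_eval_set):
  # my_eval_set: an EvalSet fixture provided by the test suite (assumed).
  await AgentEvaluator.evaluate_eval_set(
      agent_module="my_project.my_agent",  # hypothetical module exposing root_agent
      eval_set=my_eval_set,
      criteria={"tool_trajectory_avg_score": 1.0},  # assumed metric key string
      num_runs=2,
  )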
@@ -109,35 +170,33 @@ async def evaluate(
     else:
       test_files = [eval_dataset_file_path_or_dir]

-    initial_session_state = {}
+    initial_session = {}
     if initial_session_file:
       with open(initial_session_file, "r") as f:
-        initial_session_state = json.loads(f.read())["state"]
+        initial_session = json.loads(f.read())

     for test_file in test_files:
-      dataset = AgentEvaluator._load_dataset(test_file)[0]
+      data = AgentEvaluator._load_dataset(test_file)[0]
       criteria = AgentEvaluator.find_config_for_test_file(test_file)
+      AgentEvaluator._validate_input([data], criteria)

-      AgentEvaluator._validate_input([dataset], criteria)
+      eval_data = {
+          "name": test_file,
+          "data": data,
+          "initial_session": initial_session,
+      }

-      evaluation_response = await AgentEvaluator._generate_responses(
-          agent_module,
-          [dataset],
-          num_runs,
+      eval_set = convert_eval_set_to_pydanctic_schema(
+          eval_set_id=str(uuid.uuid4()), eval_set_in_json_format=[eval_data]
+      )
+      await AgentEvaluator.evaluate_eval_set(
+          agent_module=agent_module,
+          eval_set=eval_set,
+          criteria=criteria,
+          num_runs=num_runs,
           agent_name=agent_name,
-          initial_session={"state": initial_session_state},
       )

-      if AgentEvaluator._response_evaluation_required(criteria, [dataset]):
-        AgentEvaluator._evaluate_response_scores(
-            agent_module, evaluation_response, criteria
-        )
-
-      if AgentEvaluator._trajectory_evaluation_required(criteria, [dataset]):
-        AgentEvaluator._evaluate_tool_trajectory(
-            agent_module, evaluation_response, criteria
-        )
-
   @staticmethod
   def _load_dataset(
       input_data: Union[str, List[str], List[Dict], List[List[Dict]]],
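The public evaluate() surface is unchanged; internally each test file is now wrapped into an eval_data dict, converted to the pydantic EvalSet schema, and handed to evaluate_eval_set. Note that the initial-session file is now forwarded whole rather than reduced to its "state" field. A call sketch with hypothetical paths:

async def run_eval():
  # Paths are hypothetical; keyword names mirror the parameters used above.
  await AgentEvaluator.evaluate(
      agent_module="my_project.my_agent",
      eval_dataset_file_path_or_dir="tests/fixture/simple.test.json",
      num_runs=2,
      initial_session_file="tests/fixture/initial.session.json",
  )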
@@ -221,102 +280,13 @@ def _validate_input(eval_dataset, criteria):
         )

   @staticmethod
-  def _get_infer_criteria(eval_dataset):
-    """Infers evaluation criteria based on the provided dataset.
-
-    Args:
-      eval_dataset (list): A list of evaluation samples.
-
-    Returns:
-      dict: Inferred evaluation criteria based on dataset fields.
-    """
-    inferred_criteria = {}
-    sample = eval_dataset[0][0]
-
-    if QUERY_COLUMN in sample and EXPECTED_TOOL_USE_COLUMN in sample:
-      inferred_criteria[TOOL_TRAJECTORY_SCORE_KEY] = DEFAULT_CRITERIA[
-          TOOL_TRAJECTORY_SCORE_KEY
-      ]
-
-    if QUERY_COLUMN in sample and REFERENCE_COLUMN in sample:
-      inferred_criteria[RESPONSE_MATCH_SCORE_KEY] = DEFAULT_CRITERIA[
-          RESPONSE_MATCH_SCORE_KEY
-      ]
-
-    return inferred_criteria
-
-  @staticmethod
-  async def _generate_responses(
-      agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
-  ):
-    """Generates evaluation responses by running the agent module multiple times."""
-    return EvaluationGenerator.generate_responses(
-        eval_dataset,
-        agent_module,
-        repeat_num=num_runs,
-        agent_name=agent_name,
-        initial_session=initial_session,
-    )
-
-  @staticmethod
-  def _response_evaluation_required(criteria, eval_dataset):
-    """Checks if response evaluation are needed."""
-    return REFERENCE_COLUMN in eval_dataset[0][0] and any(
-        key in criteria
-        for key in [RESPONSE_EVALUATION_SCORE_KEY, RESPONSE_MATCH_SCORE_KEY]
-    )
-
-  @staticmethod
-  def _trajectory_evaluation_required(evaluation_criteria, eval_dataset):
-    """Checks if response evaluation are needed."""
-    return (
-        EXPECTED_TOOL_USE_COLUMN in eval_dataset[0][0]
-        and TOOL_TRAJECTORY_SCORE_KEY in evaluation_criteria
-    )
-
-  @staticmethod
-  def _evaluate_response_scores(agent_module, evaluation_response, criteria):
-    """Evaluates response scores and raises an assertion error if they don't meet the criteria."""
-    metrics = ResponseEvaluator.evaluate(
-        evaluation_response, criteria, print_detailed_results=True
-    )
-
-    AgentEvaluator._assert_score(
-        metrics,
-        "coherence/mean",
-        criteria.get(RESPONSE_EVALUATION_SCORE_KEY),
-        "Average response evaluation score",
-        agent_module,
-    )
-
-    AgentEvaluator._assert_score(
-        metrics,
-        "rouge_1/mean",
-        criteria.get(RESPONSE_MATCH_SCORE_KEY),
-        "Average response match score",
-        agent_module,
-    )
-
-  @staticmethod
-  def _evaluate_tool_trajectory(agent_module, evaluation_response, criteria):
-    """Evaluates tool trajectory scores and raises an assertion error if they don't meet the criteria."""
-    score = TrajectoryEvaluator.evaluate(
-        evaluation_response, print_detailed_results=True
-    )
-    AgentEvaluator._assert_score(
-        {TOOL_TRAJECTORY_SCORE_KEY: score},
-        TOOL_TRAJECTORY_SCORE_KEY,
-        criteria[TOOL_TRAJECTORY_SCORE_KEY],
-        "Average tool trajectory evaluation score",
-        agent_module,
-    )
+  def _get_metric_evaluator(metric_name: str, threshold: float) -> Evaluator:
+    if metric_name == TOOL_TRAJECTORY_SCORE_KEY:
+      return TrajectoryEvaluator(threshold=threshold)
+    elif (
+        metric_name == RESPONSE_MATCH_SCORE_KEY
+        or metric_name == RESPONSE_EVALUATION_SCORE_KEY
+    ):
+      return ResponseEvaluator(threshold=threshold, metric_name=metric_name)

-  @staticmethod
-  def _assert_score(metrics, metric_key, threshold, description, agent_module):
-    """Asserts that a metric meets the specified threshold."""
-    if metric_key in metrics:
-      actual_score = metrics[metric_key]
-      assert actual_score >= threshold, (
-          f"{description} for {agent_module} is lower than expected. "
-          f"Expected >= {threshold}, but got {actual_score}."
-      )
+    raise ValueError(f"Unsupported eval metric: {metric_name}")
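With the per-metric helpers and _assert_score removed, every criterion resolves through this single dispatch point, and an unrecognized metric name now raises instead of being silently ignored. A small sketch (the metric-name strings are assumed values; the constants above are the source of truth):

from google.adk.evaluation.agent_evaluator import AgentEvaluator

trajectory_eval = AgentEvaluator._get_metric_evaluator(
    metric_name="tool_trajectory_avg_score", threshold=1.0  # assumed key string
)
response_eval = AgentEvaluator._get_metric_evaluator(
    metric_name="response_match_score", threshold=0.8  # assumed key string
)
try:
  AgentEvaluator._get_metric_evaluator(metric_name="unknown_metric", threshold=0.5)
except ValueError as err:
  print(err)  # Unsupported eval metric: unknown_metric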

src/google/adk/evaluation/evaluation_generator.py

Lines changed: 44 additions & 17 deletions
@@ -13,32 +13,46 @@
 # limitations under the License.

 import importlib
-from typing import Any, Optional
+from typing import Any
+from typing import Optional
 import uuid

+from pydantic import BaseModel
+
 from ..agents.llm_agent import Agent
 from ..artifacts.base_artifact_service import BaseArtifactService
 from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
 from ..runners import Runner
 from ..sessions.base_session_service import BaseSessionService
 from ..sessions.in_memory_session_service import InMemorySessionService
 from ..sessions.session import Session
+from .eval_case import EvalCase
 from .eval_case import IntermediateData
 from .eval_case import Invocation
 from .eval_case import SessionInput
+from .eval_set import EvalSet
+
+
+class EvalCaseResponses(BaseModel):
+  """Contains multiple responses associated with an EvalCase.
+
+  Multiple responses are a result of repeated requests to genereate inferences.
+  """
+
+  eval_case: EvalCase
+  responses: list[list[Invocation]]


 class EvaluationGenerator:
   """Generates evaluation responses for agents."""

   @staticmethod
   async def generate_responses(
-      eval_dataset,
-      agent_module_path,
-      repeat_num=3,
-      agent_name=None,
-      initial_session={},
-  ):
+      eval_set: EvalSet,
+      agent_module_path: str,
+      repeat_num: int = 3,
+      agent_name: str = None,
+  ) -> list[EvalCaseResponses]:
     """Returns evaluation responses for the given dataset and agent.

     Args:
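The new EvalCaseResponses model groups all repeated runs for one eval case; responses is indexed as responses[run][invocation]. A small consumer-side helper sketch (not part of the commit) that mirrors the flattening done in AgentEvaluator.evaluate_eval_set:

from google.adk.evaluation.evaluation_generator import EvalCaseResponses

def flatten_runs(case_responses: EvalCaseResponses) -> list:
  # responses[run][invocation] -> one flat list of Invocations across all runs.
  return [
      invocation
      for run in case_responses.responses
      for invocation in run
  ]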
@@ -48,17 +62,23 @@ async def generate_responses(
         usually done to remove uncertainty that a single run may bring.
       agent_name: The name of the agent that should be evaluated. This is
         usually the sub-agent.
-      initial_session: Initial session for the eval data.
     """
     results = []

-    for _ in range(repeat_num):
-      for data in eval_dataset:
-        results.append(
-            EvaluationGenerator._process_query(
-                data, agent_module_path, agent_name, initial_session
-            )
+    for eval_case in eval_set.eval_cases:
+      responses = []
+      for _ in range(repeat_num):
+        response_invocations = await EvaluationGenerator._process_query(
+            eval_case.conversation,
+            agent_module_path,
+            agent_name,
+            eval_case.session_input,
         )
+        responses.append(response_invocations)
+
+      results.append(
+          EvalCaseResponses(eval_case=eval_case, responses=responses)
+      )

     return results
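Callers now pass an EvalSet and await the coroutine; the keyword-style call below mirrors how AgentEvaluator.evaluate_eval_set drives it (the agent module path is hypothetical):

from google.adk.evaluation.evaluation_generator import EvaluationGenerator

async def run_inference(eval_set):
  # eval_set: a populated EvalSet (e.g. from convert_eval_set_to_pydanctic_schema).
  return await EvaluationGenerator.generate_responses(
      eval_set=eval_set,
      agent_module_path="my_project.my_agent",  # hypothetical module path
      repeat_num=2,
  )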

@@ -89,7 +109,12 @@ def generate_responses_from_session(session_path, eval_dataset):
     return results

   @staticmethod
-  def _process_query(data, module_name, agent_name=None, initial_session={}):
+  async def _process_query(
+      invocations: list[Invocation],
+      module_name: str,
+      agent_name: Optional[str] = None,
+      initial_session: Optional[SessionInput] = None,
+  ) -> list[Invocation]:
     """Process a query using the agent and evaluation dataset."""
     module_path = f"{module_name}"
     agent_module = importlib.import_module(module_path)
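_process_query is now a coroutine that receives the expected invocations plus an optional SessionInput instead of a raw dict, so callers must await it, as generate_responses does above. A call-shape sketch (eval_case is assumed to be an EvalCase drawn from the eval set):

async def infer_one_case(eval_case):
  return await EvaluationGenerator._process_query(
      eval_case.conversation,     # expected invocations that drive the queries
      "my_project.my_agent",      # hypothetical agent module path
      None,                       # agent_name: evaluate the root agent
      eval_case.session_input,    # optional SessionInput, may be None
  )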
@@ -102,8 +127,8 @@ def _process_query(data, module_name, agent_name=None, initial_session={}):
       agent_to_evaluate = root_agent.find_agent(agent_name)
       assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."

-    return EvaluationGenerator._generate_inferences_from_root_agent(
-        data, agent_to_evaluate, reset_func, initial_session
+    return await EvaluationGenerator._generate_inferences_from_root_agent(
+        invocations, agent_to_evaluate, reset_func, initial_session
     )

   @staticmethod
@@ -216,3 +241,5 @@ def _process_query_with_session(session_data, data):
       responses[index]["actual_tool_use"] = actual_tool_uses
       responses[index]["response"] = response
     return responses
+    return responses
+    return responses
