10000 Refactor Eval Set Management into its own class. · Syntax404-coder/adk-python@cf06cc5 · GitHub
[go: up one dir, main page]

Skip to content

Commit cf06cc5

Browse files
ankursharmascopybara-github
authored andcommitted
Refactor Eval Set Management into its own class.
PiperOrigin-RevId: 758378377
1 parent 303af44 commit cf06cc5

File tree

3 files changed

+162
-62
lines changed

3 files changed

+162
-62
lines changed

src/google/adk/cli/fast_api.py

Lines changed: 16 additions & 62 deletions
D7AE
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from ..agents.llm_agent import LlmAgent
6464
from ..agents.run_config import StreamingMode
6565
from ..artifacts import InMemoryArtifactService
66+
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
6667
from ..events.event import Event
6768
from ..memory.in_memory_memory_service import InMemoryMemoryService
6869
from ..runners import Runner
@@ -252,6 +253,8 @@ async def internal_lifespan(app: FastAPI):
252253
artifact_service = InMemoryArtifactService()
253254
memory_service = InMemoryMemoryService()
254255

256+
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
257+
255258
# Build the Session service
256259
agent_engine_id = ""
257260
if session_db_url:
@@ -401,44 +404,21 @@ def create_eval_set(
401404
eval_set_id: str,
402405
):
403406
"""Creates an eval set, given the id."""
404-
pattern = r"^[a-zA-Z0-9_]+$"
405-
if not bool(re.fullmatch(pattern, eval_set_id)):
407+
try:
408+
eval_sets_manager.create_eval_set(app_name, eval_set_id)
409+
except ValueError as ve:
406410
raise HTTPException(
407411
status_code=400,
408-
detail=(
409-
f"Invalid eval set id. Eval set id should have the `{pattern}`"
410-
" format"
411-
),
412-
)
413-
# Define the file path
414-
new_eval_set_path = _get_eval_set_file_path(
415-
app_name, agent_dir, eval_set_id
416-
)
417-
418-
logger.info("Creating eval set file `%s`", new_eval_set_path)
419-
420-
if not os.path.exists(new_eval_set_path):
421-
# Write the JSON string to the file
422-
logger.info("Eval set file doesn't exist, we will create a new one.")
423-
with open(new_eval_set_path, "w") as f:
424-
empty_content = json.dumps([], indent=2)
425-
f.write(empty_content)
412+
detail=str(ve),
413+
) from ve
426414

427415
@app.get(
428416
"/apps/{app_name}/eval_sets",
429417
response_model_exclude_none=True,
430418
)
431419
def list_eval_sets(app_name: str) -> list[str]:
432420
"""Lists all eval sets for the given app."""
433-
eval_set_file_path = os.path.join(agent_dir, app_name)
434-
eval_sets = []
435-
for file in os.listdir(eval_set_file_path):
436-
if file.endswith(_EVAL_SET_FILE_EXTENSION):
437-
eval_sets.append(
438-
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
439-
)
440-
441-
return sorted(eval_sets)
421+
return eval_sets_manager.list_eval_sets(app_name)
442422

443423
@app.post(
444424
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
@@ -447,33 +427,11 @@ def list_eval_sets(app_name: str) -> list[str]:
447427
async def add_session_to_eval_set(
448428
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
449429
):
450-
pattern = r"^[a-zA-Z0-9_]+$"
451-
if not bool(re.fullmatch(pattern, req.eval_id)):
452-
raise HTTPException(
453-
status_code=400,
454-
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
455-
)
456-
457430
# Get the session
458431
session = session_service.get_session(
459432
app_name=app_name, user_id=req.user_id, session_id=req.session_id
460433
)
461434
assert session, "Session not found."
462-
# Load the eval set file data
463-
eval_set_file_path = _get_eval_set_file_path(
464-
app_name, agent_dir, eval_set_id
465-
)
466-
with open(eval_set_file_path, "r") as file:
467-
eval_set_data = json.load(file) # Load JSON into a list
468-
469-
if [x for x in eval_set_data if x["name"] == req.eval_id]:
470-
raise HTTPException(
471-
status_code=400,
472-
detail=(
473-
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
474-
" eval set."
475-
),
476-
)
477435

478436
# Convert the session data to evaluation format
479437
test_data = evals.convert_session_to_eval_format(session)
@@ -483,18 +441,19 @@ async def add_session_to_eval_set(
483441
await _get_root_agent_async(app_name)
484442
)
485443

486-
eval_set_data.append({
444+
eval_case = {
487445
"name": req.eval_id,
488446
"data": test_data,
489447
"initial_session": {
490448
"state": initial_session_state,
491449
"app_name": app_name,
492450
"user_id": req.user_id,
493451
},
494-
})
495-
# Serialize the test data to JSON and write to the eval set file.
496-
with open(eval_set_file_path, "w") as f:
497-
f.write(json.dumps(eval_set_data, indent=2))
452+
}
453+
try:
454+
eval_sets_manager.add_eval_case(app_name, eval_set_id, eval_case)
455+
except ValueError as ve:
456+
raise HTTPException(status_code=400, detail=str(ve)) from ve
498457

499458
@app.get(
500459
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
@@ -505,12 +464,7 @@ def list_evals_in_eval_set(
505464
eval_set_id: str,
506465
) -> list[str]:
507466
"""Lists all evals in an eval set."""
508-
# Load the eval set file data
509-
eval_set_file_path = _get_eval_set_file_path(
510-
app_name, agent_dir, eval_set_id
511-
)
512-
with open(eval_set_file_path, "r") as file:
513-
eval_set_data = json.load(file) # Load JSON into a list
467+
eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
514468

515469
return sorted([x["name"] for x in eval_set_data])
516470

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from abc import ABC, abstractmethod
16+
from typing import Any
17+
18+
19+
class EvalSetsManager(ABC):
20+
"""An interface to manage an Eval Sets."""
21+
22+
@abstractmethod
23+
def get_eval_set(self, app_name: str, eval_set_id: str) -> Any:
24+
"""Returns an EvalSet identified by an app_name and eval_set_id."""
25+
raise NotImplementedError()
26+
27+
@abstractmethod
28+
def create_eval_set(self, app_name: str, eval_set_id: str):
29+
"""Creates an empty EvalSet given the app_name and eval_set_id."""
30+
raise NotImplementedError()
31+
32+
@abstractmethod
33+
def list_eval_sets(self, app_name: str) -> list[str]:
34+
"""Returns a list of EvalSets that belong to the given app_name."""
35+
raise NotImplementedError()
36+
37+
@abstractmethod
38+
def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: Any):
39+
"""Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id."""
40+
raise NotImplementedError()
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
import logging
17+
import os
18+
import re
19+
from typing import Any
20+
from typing_extensions import override
21+
from .eval_sets_manager import EvalSetsManager
22+
23+
logger = logging.getLogger(__name__)
24+
25+
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
26+
27+
28+
class LocalEvalSetsManager(EvalSetsManager):
29+
"""An EvalSets manager that stores eval sets locally on disk."""
30+
31+
def __init__(self, agent_dir: str):
32+
self._agent_dir = agent_dir
33+
34+
@override
35+
def get_eval_set(self, app_name: str, eval_set_id: str) -> Any:
36+
"""Returns an EvalSet identified by an app_name and eval_set_id."""
37+
# Load the eval set file data
38+
eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
39+
with open(eval_set_file_path, "r") as file:
40+
return json.load(file) # Load JSON into a list
41+
42+
@override
43+
def create_eval_set(self, app_name: str, eval_set_id: str):
44+
"""Creates an empty EvalSet given the app_name and eval_set_id."""
45+
self._validate_id(id_name="Eval Set Id", id_value=eval_set_id)
46+
47+
# Define the file path
48+
new_eval_set_path = self._get_eval_set_file_path(app_name, eval_set_id)
49+
50+
logger.info("Creating eval set file `%s`", new_eval_set_path)
51+
52+
if not os.path.exists(new_eval_set_path):
53+
# Write the JSON string to the file
54+
logger.info("Eval set file doesn't exist, we will create a new one.")
55+
with open(new_eval_set_path, "w") as f:
56+
empty_content = json.dumps([], indent=2)
57+
f.write(empty_content)
58+
59+
@override
60+
def list_eval_sets(self, app_name: str) -> list[str]:
61+
"""Returns a list of EvalSets that belong to the given app_name."""
62+
eval_set_file_path = os.path.join(self._agent_dir, app_name)
63+
eval_sets = []
64+
for file in os.listdir(eval_set_file_path):
65+
if file.endswith(_EVAL_SET_FILE_EXTENSION):
66+
eval_sets.append(
67+
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
68+
)
69+
70+
return sorted(eval_sets)
71+
72+
@override
73+
def add_eval_case(self, app_name: str, eval_set_id: str, eval_case: Any):
74+
"""Adds the given EvalCase to an existing EvalSet identified by app_name and eval_set_id."""
75+
eval_case_id = eval_case["name"]
76+
self._validate_id(id_name="Eval Case Id", id_value=eval_case_id)
77+
78+
# Load the eval set file data
79+
eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
80+
with open(eval_set_file_path, "r") as file:
81+
eval_set_data = json.load(file) # Load JSON into a list
82+
83+
if [x for x in eval_set_data if x["name"] == eval_case_id]:
84+
raise ValueError(
85+
f"Eval id `{eval_case_id}` already exists in `{eval_set_id}`"
86+
" eval set.",
87+
)
88+
89+
eval_set_data.append(eval_case)
90+
# Serialize the test data to JSON and write to the eval set file.
91+
with open(eval_set_file_path, "w") as f:
92+
f.write(json.dumps(eval_set_data, indent=2))
93+
94+
def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
95+
return os.path.join(
96+
self._agent_dir,
97+
app_name,
98+
eval_set_id + _EVAL_SET_FILE_EXTENSION,
99+
)
100+
101+
def _validate_id(self, id_name: str, id_value: str):
102+
pattern = r"^[a-zA-Z0-9_]+$"
103+
if not bool(re.fullmatch(pattern, id_value)):
104+
raise ValueError(
105+
f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
106+
)

0 commit comments

Comments
 (0)
0