8000 Merge branch 'main' into support-async-tool-callbacks · tsayan/adk-python@926b0ef · GitHub
[go: up one dir, main page]

Skip to content

Commit 926b0ef

Browse files
Merge branch 'main' into support-async-tool-callbacks
2 parents fcbf574 + dbbeb19 commit 926b0ef

File tree

5 files changed

+163
-69
lines changed

5 files changed

+163
-69
lines changed

src/google/adk/cli/cli.py

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@ class InputFile(BaseModel):
3939

4040
async def run_input_file(
4141
app_name: str,
42+
user_id: str,
4243
root_agent: LlmAgent,
4344
artifact_service: BaseArtifactService,
44-
session: Session,
4545
session_service: BaseSessionService,
4646
input_path: str,
47-
) -> None:
47+
) -> Session:
4848
runner = Runner(
4949
app_name=app_name,
5050
agent=root_agent,
@@ -55,33 +55,35 @@ async def run_input_file(
5555
input_file = InputFile.model_validate_json(f.read())
5656
input_file.state['_time'] = datetime.now()
5757

58-
session.state = input_file.state
58+
session = session_service.create_session(
59+
app_name=app_name, user_id=user_id, state=input_file.state
60+
)
5961
for query in input_file.queries:
60-
click.echo(f'user: {query}')
62+
click.echo(f'[user]: {query}')
6163
content = types.Content(role='user', parts=[types.Part(text=query)])
6264
async for event in runner.run_async(
6365
user_id=session.user_id, session_id=session.id, new_message=content
6466
):
6567
if event.content and event.content.parts:
6668
if text := ''.join(part.text or '' for part in event.content.parts):
6769
click.echo(f'[{event.author}]: {text}')
70+
return session
6871

6972

7073
async def run_interactively(
71-
app_name: str,
7274
root_agent: LlmAgent,
7375
artifact_service: BaseArtifactService,
7476
session: Session,
7577
session_service: BaseSessionService,
7678
) -> None:
7779
runner = Runner(
78-
app_name=app_name,
80+
app_name=session.app_name,
7981
agent=root_agent,
8082
artifact_service=artifact_service,
8183
session_service=session_service,
8284
)
8385
while True:
84-
query = input('user: ')
86+
query = input('[user]: ')
8587
if not query or not query.strip():
8688
continue
8789
if query == 'exit':
@@ -100,7 +102,8 @@ async def run_cli(
100102
*,
101103
agent_parent_dir: str,
102104
agent_folder_name: str,
103-
json_file_path: Optional[str] = None,
105+
input_file: Optional[str] = None,
106+
saved_session_file: Optional[str] = None,
104107
save_session: bool,
105108
) -> None:
106109
"""Runs an interactive CLI for a certain agent.
@@ -109,67 +112,71 @@ async def run_cli(
109112
agent_parent_dir: str, the absolute path of the parent folder of the agent
110113
folder.
111114
agent_folder_name: str, the name of the agent folder.
112-
json_file_path: Optional[str], the absolute path to the json file, either
113-
*.input.json or *.session.json.
115+
input_file: Optional[str], the absolute path to the json file that contains
116+
the initial session state and user queries, exclusive with
117+
saved_session_file.
118+
saved_session_file: Optional[str], the absolute path to the json file that
119+
contains a previously saved session, exclusive with input_file.
114120
save_session: bool, whether to save the session on exit.
115121
"""
116122
if agent_parent_dir not in sys.path:
117123
sys.path.append(agent_parent_dir)
118124

119125
artifact_service = InMemoryArtifactService()
120126
session_service = InMemorySessionService()
121-
session = session_service.create_session(
122-
app_name=agent_folder_name, user_id='test_user'
123-
)
124127

125128
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
126129
agent_module = importlib.import_module(agent_folder_name)
130+
user_id = 'test_user'
131+
session = session_service.create_session(
132+
app_name=agent_folder_name, user_id=user_id
133+
)
127134
root_agent = agent_module.agent.root_agent
128135
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
129-
if json_file_path:
130-
if json_file_path.endswith('.input.json'):
131-
await run_input_file(
132-
app_name=agent_folder_name,
133-
root_agent=root_agent,
134-
artifact_service=artifact_service,
135-
session=session,
136-
session_service=session_service,
137-
input_path=json_file_path,
138-
)
139-
elif json_file_path.endswith('.session.json'):
140-
with open(json_file_path, 'r') as f:
141-
session = Session.model_validate_json(f.read())
142-
for content in session.get_contents():
143-
if content.role == 'user':
144-
print('user: ', content.parts[0].text)
136+
if input_file:
137+
session = await run_input_file(
138+
app_name=agent_folder_name,
139+
user_id=user_id,
140+
root_agent=root_agent,
141+
artifact_service=artifact_service,
142+
session_service=session_service,
143+
input_path=input_file,
144+
)
145+
elif saved_session_file:
146+
147+
loaded_session = None
148+
with open(saved_session_file, 'r') as f:
149+
loaded_session = Session.model_validate_json(f.read())
150+
151+
if loaded_session:
152+
for event in loaded_session.events:
153+
session_service.append_event(session, event)
154+
content = event.content
155+
if not content or not content.parts or not content.parts[0].text:
156+
continue
157+
if event.author == 'user':
158+
click.echo(f'[user]: {content.parts[0].text}')
145159
else:
146-
print(content.parts[0].text)
147-
await run_interactively(
148-
agent_folder_name,
149-
root_agent,
150-
artifact_service,
151-
session,
152-
session_service,
153-
)
154-
else:
155-
print(f'Unsupported file type: {json_file_path}')
156-
exit(1)
160+
click.echo(f'[{event.author}]: {content.parts[0].text}')
161+
162+
await run_interactively(
163+
root_agent,
164+
artifact_service,
165+
session,
166+
session_service,
167+
)
157168
else:
158-
print(f'Running agent {root_agent.name}, type exit to exit.')
169+
click.echo(f'Running agent {root_agent.name}, type exit to exit.')
159170
await run_interactively(
160-
agent_folder_name,
161171
root_agent,
162172
artifact_service,
163173
session,
164174
session_service,
165175
)
166176

167177
if save_session:
168-
if json_file_path:
169-
session_path = json_file_path.replace('.input.json', '.session.json')
170-
else:
171-
session_id = input('Session ID to save: ')
172-
session_path = f'{agent_module_path}/{session_id}.session.json'
178+
session_id = input('Session ID to save: ')
179+
session_path = f'{agent_module_path}/{session_id}.session.json'
173180

174181
# Fetch the session again to get all the details.
175182
session = session_service.get_session(

src/google/adk/cli/cli_tools_click.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,23 @@ def cli_create_cmd(
9696
)
9797

9898

99+
def validate_exclusive(ctx, param, value):
100+
# Store the validated parameters in the context
101+
if not hasattr(ctx, "exclusive_opts"):
102+
ctx.exclusive_opts = {}
103+
104+
# If this option has a value and we've already seen another exclusive option
105+
if value is not None and any(ctx.exclusive_opts.values()):
106+
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
107+
raise click.UsageError(
108+
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
109+
)
110+
111+
# Record this option's value
112+
ctx.exclusive_opts[param.name] = value is not None
113+
return value
114+
115+
99116
@main.command("run")
100117
@click.option(
101118
"--save_session",
@@ -105,13 +122,43 @@ def cli_create_cmd(
105122
default=False,
106123
help="Optional. Whether to save the session to a json file on exit.",
107124
)
125+
@click.option(
126+
"--replay",
127+
type=click.Path(
128+
exists=True, dir_okay=False, file_okay=True, resolve_path=True
129+
),
130+
help=(
131+
"The json file that contains the initial state of the session and user"
132+
" queries. A new session will be created using this state. And user"
133+
" queries are run againt the newly created session. Users cannot"
134+
" continue to interact with the agent."
135+
),
136+
callback=validate_exclusive,
137+
)
138+
@click.option(
139+
"--resume",
140+
type=click.Path(
141+
exists=True, dir_okay=False, file_okay=True, resolve_path=True
142+
),
143+
help=(
144+
"The json file that contains a previously saved session (by"
145+
"--save_session option). The previous session will be re-displayed. And"
146+
" user can continue to interact with the agent."
147+
),
148+
callback=validate_exclusive,
149+
)
108150
@click.argument(
109151
"agent",
110152
type=click.Path(
111153
exists=True, dir_okay=True, file_okay=False, resolve_path=True
112154
),
113155
)
114-
def cli_run(agent: str, save_session: bool):
156+
def cli_run(
157+
agent: str,
158+
save_session: bool,
159+
replay: Optional[str],
160+
resume: Optional[str],
161+
):
115162
"""Runs an interactive CLI for a certain agent.
116163
117164
AGENT: The path to the agent source code folder.
@@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
129176
run_cli(
130177
agent_parent_dir=agent_parent_folder,
131178
agent_folder_name=agent_folder_name,
179+
input_file=replay,
180+
saved_session_file=resume,
132181
save_session=save_session,
133182
)
134183
)

src/google/adk/events/event_actions.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,13 @@ class EventActions(BaseModel):
4848
"""The agent is escalating to a higher level agent."""
4949

5050
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
51-
"""Will only be set by a tool response indicating tool request euc.
52-
dict key is the function call id since one function call response (from model)
53-
could correspond to multiple function calls.
54-
dict value is the required auth config.
51+
"""Authentication configurations requested by tool responses.
52+
53+
This field will only be set by a tool response event indicating tool request
54+
auth credential.
55+
- Keys: The function call id. Since one function response event could contain
56+
multiple function responses that correspond to multiple function calls. Each
57+
function call could request different auth configs. This id is used to
58+
identify the function call.
59+
- Values: The requested auth config.
5560
"""

src/google/adk/sessions/database_session_service.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858

5959
logger = logging.getLogger(__name__)
6060

61+
DEFAULT_MAX_VARCHAR_LENGTH = 256
62+
6163

6264
class DynamicJSON(TypeDecorator):
6365
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
@@ -92,17 +94,25 @@ def process_result_value(self, value, dialect: Dialect):
9294

9395
class Base(DeclarativeBase):
9496
"""Base class for database tables."""
97+
9598
pass
9699

97100 10000

98101
class StorageSession(Base):
99102
"""Represents a session stored in the database."""
103+
100104
__tablename__ = "sessions"
101105

102-
app_name: Mapped[str] = mapped_column(String, primary_key=True)
103-
user_id: Mapped[str] = mapped_column(String, primary_key=True)
106+
app_name: Mapped[str] = mapped_column(
107+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
108+
)
109+
user_id: Mapped[str] = mapped_column(
110+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
111+
)
104112
id: Mapped[str] = mapped_column(
105-
String, primary_key=True, default=lambda: str(uuid.uuid4())
113+
String(DEFAULT_MAX_VARCHAR_LENGTH),
114+
primary_key=True,
115+
default=lambda: str(uuid.uuid4()),
106116
)
107117

108118
state: Mapped[MutableDict[str, Any]] = mapped_column(
@@ -125,16 +135,27 @@ def __repr__(self):
125135

126136
class StorageEvent(Base):
127137
"""Represents an event stored in the database."""
138+
128139
__tablename__ = "events"
129140

130-
id: Mapped[str] = mapped_column(String, primary_key=True)
131-
app_name: Mapped[str] = mapped_column(String, primary_key=True)
132-
user_id: Mapped[str] = mapped_column(String, primary_key=True)
133-
session_id: Mapped[str] = mapped_column(String, primary_key=True)
141+
id: Mapped[str] = mapped_column(
142+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
143+
)
144+
app_name: Mapped[str] = mapped_column(
145+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
146+
)
147+
user_id: Mapped[str] = mapped_column(
148+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
149+
)
150+
session_id: Mapped[str] = mapped_column(
151+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
152+
)
134153

135-
invocation_id: Mapped[str] = mapped_column(String)
136-
author: Mapped[str] = mapped_column(String)
137-
branch: Mapped[str] = mapped_column(String, nullable=True)
154+
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
155+
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
156+
branch: Mapped[str] = mapped_column(
157+
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
158+
)
138159
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
139160
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
140161
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
@@ -147,8 +168,10 @@ class StorageEvent(Base):
147168
)
148169
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
149170
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
150-
error_code: Mapped[str] = mapped_column(String, nullable=True)
151-
error_message: Mapped[str] = mapped_column(String, nullable=True)
171+
error_code: Mapped[str] = mapped_column(
172+
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
173+
)
174+
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
152175
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
153176

154177
storage_session: Mapped[StorageSession] = relationship(
@@ -182,23 +205,33 @@ def long_running_tool_ids(self, value: set[str]):
182205

183206
class StorageAppState(Base):
184207
"""Represents an app state stored in the database."""
208+
185209
__tablename__ = "app_states"
186210

187-
app_name: Mapped[str] = mapped_column(String, primary_key=True)
211+
app_name: Mapped[str] = mapped_column(
212+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
213+
)
188214
state: Mapped[MutableDict[str, Any]] = mapped_column(
189215
MutableDict.as_mutable(DynamicJSON), default={}
190216
)
191217
update_time: Mapped[DateTime] = mapped_column(
192218
DateTime(), default=func.now(), onupdate=func.now()
193219
)
194220

195-< A413 /span>
196221
class StorageUserState(Base):
197222
"""Represents a user state stored in the database."""
223+
198224
__tablename__ = "user_states"
199225

200-
app_name: Mapped[str] = mapped_column(String, primary_key=True)
201-
user_id: Mapped[str] = mapped_column(String, primary_key=True)
226+
app_name: Mapped[str] = mapped_column(
227+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
228+
)
229+
user_id: Mapped[str] = mapped_column(
230+
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
231+
)
232+
state: Mapped[MutableDict[str, Any]] = mapped_column(
233+
MutableDict.as_mutable(DynamicJSON), default={}
234+
)
202235
state: Mapped[MutableDict[str, Any]] = mapped_column(
203236
MutableDict.as_mutable(DynamicJSON), default={}
204237
)

0 commit comments

Comments
 (0)
0