8000 feat: add memory_service option to CLI · devevignesh/adk-python@416dc6f · GitHub
[go: up one dir, main page]

Skip to content

Commit 416dc6f

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: add memory_service option to CLI
chore: consolidate ADK service CLI options PiperOrigin-RevId: 769944881
1 parent 9df3f72 commit 416dc6f

File tree

5 files changed

+218
-70
lines changed

5 files changed

+218
-70
lines changed

src/google/adk/cli/cli_deploy.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
5656
EXPOSE {port}
5757
58-
CMD adk {command} --port={port} {host_option} {session_db_option} {trace_to_cloud_option} "/app/agents"
58+
CMD adk {command} --port={port} {host_option} {service_option} {trace_to_cloud_option} "/app/agents"
5959
"""
6060

6161
_AGENT_ENGINE_APP_TEMPLATE = """
@@ -84,6 +84,32 @@ def _resolve_project(project_in_option: Optional[str]) -> str:
8484
return project
8585

8686

87+
def _get_service_option_by_adk_version(
88+
adk_version: str,
89+
session_uri: Optional[str],
90+
artifact_uri: Optional[str],
91+
memory_uri: Optional[str],
92+
) -> str:
93+
"""Returns service option string based on adk_version."""
94+
if adk_version >= '1.3.0':
95+
session_option = (
96+
f'--session_service_uri={session_uri}' if session_uri else ''
97+
)
98+
artifact_option = (
99+
f'--artifact_service_uri={artifact_uri}' if artifact_uri else ''
100+
)
101+
memory_option = f'--memory_service_uri={memory_uri}' if memory_uri else ''
102+
return f'{session_option} {artifact_option} {memory_option}'
103+
elif adk_version >= '1.2.0':
104+
session_option = f'--session_db_url={session_uri}' if session_uri else ''
105+
artifact_option = (
106+
f'--artifact_storage_uri={artifact_uri}' if artifact_uri else ''
107+
)
108+
return f'{session_option} {artifact_option}'
109+
else:
110+
return f'--session_db_url={session_uri}' if session_uri else ''
111+
112+
87113
def to_cloud_run(
88114
*,
89115
agent_folder: str,
@@ -96,9 +122,10 @@ def to_cloud_run(
96122
trace_to_cloud: bool,
97123
with_ui: bool,
98124
verbosity: str,
99-
session_db_url: str,
100-
artifact_storage_uri: Optional[str],
101125
adk_version: str,
126+
session_service_uri: Optional[str] = None,
127+
artifact_service_uri: Optional[str] = None,
128+
memory_service_uri: Optional[str] = None,
102129
):
103130
"""Deploys an agent to Google Cloud Run.
104131
@@ -126,9 +153,10 @@ def to_cloud_run(
126153
trace_to_cloud: Whether to enable Cloud Trace.
127154
with_ui: Whether to deploy with UI.
128155
verbosity: The verbosity level of the CLI.
129-
session_db_url: The database URL to connect the session.
130-
artifact_storage_uri: The artifact storage URI to store the artifacts.
131156
adk_version: The ADK version to use in Cloud Run.
157+
session_service_uri: The URI of the session service.
158+
artifact_service_uri: The URI of the artifact service.
159+
memory_service_uri: The URI of the memory service.
132160
"""
133161
app_name = app_name or os.path.basename(agent_folder)
134162

@@ -162,12 +190,12 @@ def to_cloud_run(
162190
port=port,
163191
command='web' if with_ui else 'api_server',
164192
install_agent_deps=install_agent_deps,
165-
session_db_option=f'--session_db_url={session_db_url}'
166-
if session_db_url
167-
else '',
168-
artifact_storage_option=f'--artifact_storage_uri={artifact_storage_uri}'
169-
if artifact_storage_uri
170-
else '',
193+
service_option=_get_service_option_by_adk_version(
194+
adk_version,
195+
session_service_uri,
196+
artifact_service_uri,
197+
memory_service_uri,
198+
),
171199
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
172200
adk_version=adk_version,
173201
host_option=host_option,

src/google/adk/cli/cli_tools_click.py

Lines changed: 104 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -417,28 +417,87 @@ async def _collect_eval_results() -> list[EvalCaseResult]:
417417
print(eval_result.model_dump_json(indent=2))
418418

419419

420-
def fast_api_common_options():
421-
"""Decorator to add common fast api options to click commands."""
420+
def adk_services_options():
421+
"""Decorator to add ADK services options to click commands."""
422422

423423
def decorator(func):
424424
@click.option(
425-
"--session_db_url",
425+
"--session_service_uri",
426426
help=(
427-
"""Optional. The database URL to store the session.
427+
"""Optional. The URI of the session service.
428428
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
429429
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
430-
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
430+
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs."""
431431
),
432432
)
433433
@click.option(
434-
"--artifact_storage_uri",
434+
"--artifact_service_uri",
435435
type=str,
436436
help=(
437-
"Optional. The artifact storage URI to store the artifacts,"
437+
"Optional. The URI of the artifact service,"
438438
" supported URIs: gs://<bucket name> for GCS artifact service."
439439
),
440440
default=None,
441441
)
442+
@click.option(
443+
"--memory_service_uri",
444+
type=str,
445+
help=(
446+
"""Optional. The URI of the memory service.
447+
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service."""
448+
),
449+
default=None,
450+
)
451+
@functools.wraps(func)
452+
def wrapper(*args, **kwargs):
453+
return func(*args, **kwargs)
454+
455+
return wrapper
456+
457+
return decorator
458+
459+
460+
def deprecated_adk_services_options():
461+
"""Depracated ADK services options."""
462+
463+
def warn(alternative_param, ctx, param, value):
464+
if value:
465+
click.echo(
466+
click.style(
467+
f"WARNING: Deprecated option {param.name} is used. Please use"
468+
f" {alternative_param} instead.",
469+
fg="yellow",
470+
),
471+
err=True,
472+
)
473+
return value
474+
475+
def decorator(func):
476+
@click.option(
477+
"--session_db_url",
478+
help="Deprecated. Use --session_service_uri instead.",
479+
callback=functools.partial(warn, "--session_service_uri"),
480+
)
481+
@click.option(
482+
"--artifact_storage_uri",
483+
type=str,
484+
help="Deprecated. Use --artifact_service_uri instead.",
485+
callback=functools.partial(warn, "--artifact_service_uri"),
486+
default=None,
487+
)
488+
@functools.wraps(func)
489+
def wrapper(*args, **kwargs):
490+
return func(*args, **kwargs)
491+
492+
return wrapper
493+
494+
return decorator
495+
496+
497+
def fast_api_common_options():
498+
"""Decorator to add common fast api options to click commands."""
499+
500+
def decorator(func):
442501
@click.option(
443502
"--host",
444503
type=str,
@@ -489,6 +548,8 @@ def wrapper(*args, **kwargs):
489548

490549
@main.command("web")
491550
@fast_api_common_options()
551+
@adk_services_options()
552+
@deprecated_adk_services_options()
492553
@click.argument(
493554
"agents_dir",
494555
type=click.Path(
@@ -498,14 +559,17 @@ def wrapper(*args, **kwargs):
498559
)
499560
def cli_web(
500561
agents_dir: str,
501-
session_db_url: str = "",
502-
artifact_storage_uri: Optional[str] = None,
503562
log_level: str = "INFO",
504563
allow_origins: Optional[list[str]] = None,
505564
host: str = "127.0.0.1",
506565
port: int = 8000,
507566
trace_to_cloud: bool = False,
508567
reload: bool = True,
568+
session_service_uri: Optional[str] = None,
569+
artifact_service_uri: Optional[str] = None,
570+
memory_service_uri: Optional[str] = None,
571+
session_db_url: Optional[str] = None, # Deprecated
572+
artifact_storage_uri: Optional[str] = None, # Deprecated
509573
):
510574
"""Starts a FastAPI server with Web UI for agents.
511575
@@ -514,7 +578,7 @@ def cli_web(
514578
515579
Example:
516580
517-
adk web --session_db_url=[db_url] --port=[port] path/to/agents_dir
581+
adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir
518582
"""
519583
logs.setup_adk_logger(getattr(logging, log_level.upper()))
520584

@@ -540,10 +604,13 @@ async def _lifespan(app: FastAPI):
540604
fg="green",
541605
)
542606

607+
session_service_uri = session_service_uri or session_db_url
608+
artifact_service_uri = artifact_service_uri or artifact_storage_uri
543609
app = get_fast_api_app(
544610
agents_dir=agents_dir,
545-
session_db_url=session_db_url,
546-
artifact_storage_uri=artifact_storage_uri,
611+
session_service_uri=session_service_uri,
612+
artifact_service_uri=artifact_service_uri,
613+
memory_service_uri=memory_service_uri,
547614
allow_origins=allow_origins,
548615
web=True,
549616
trace_to_cloud=trace_to_cloud,
@@ -571,16 +638,21 @@ async def _lifespan(app: FastAPI):
571638
default=os.getcwd(),
572639
)
573640
@fast_api_common_options()
641+
@adk_services_options()
642+
@deprecated_adk_services_options()
574643
def cli_api_server(
575644
agents_dir: str,
576-
session_db_url: str = "",
577-
artifact_storage_uri: Optional[str] = None,
578645
log_level: str = "INFO",
579646
allow_origins: Optional[list[str]] = None,
580647
host: str = "127.0.0.1",
581648
port: int = 8000,
582649
trace_to_cloud: bool = False,
583650
reload: bool = True,
651+
session_service_uri: Optional[str] = None,
652+
artifact_service_uri: Optional[str] = None,
653+
memory_service_uri: Optional[str] = None,
654+
session_db_url: Optional[str] = None, # Deprecated
655+
artifact_storage_uri: Optional[str] = None, # Deprecated
584656
):
585657
"""Starts a FastAPI server for agents.
586658
@@ -589,15 +661,18 @@ def cli_api_server(
589661
590662
Example:
591663
592-
adk api_server --session_db_url=[db_url] --port=[port] path/to/agents_dir
664+
adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir
593665
"""
594666
logs.setup_adk_logger(getattr(logging, log_level.upper()))
595667

668+
session_service_uri = session_service_uri or session_db_url
669+
artifact_service_uri = artifact_service_uri or artifact_storage_uri
596670
config = uvicorn.Config(
597671
get_fast_api_app(
598672
agents_dir=agents_dir,
599-
session_db_url=session_db_url,
600-
artifact_storage_uri=artifact_storage_uri,
673+
session_service_uri=session_service_uri,
674+
artifact_service_uri=artifact_service_uri,
675+
memory_service_uri=memory_service_uri,
601676
allow_origins=allow_origins,
602677
web=False,
603678
trace_to_cloud=trace_to_cloud,
@@ -689,27 +764,6 @@ def cli_api_server(
689764
default="WARNING",
690765
help="Optional. Override the default verbosity level.",
691766
)
692-
@click.option(
693-
"--session_db_url",
694-
help=(
695-
"""Optional. The database URL to store the session.
696-
697-
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
698-
699-
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
700-
701-
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
702-
),
703-
)
704-
@click.option(
705-
"--artifact_storage_uri",
706-
type=str,
707-
help=(
708-
"Optional. The artifact storage URI to store the artifacts, supported"
709-
" URIs: gs://<bucket name> for GCS artifact service."
710-
),
711-
default=None,
712-
)
713767
@click.argument(
714768
"agent",
715769
type=click.Path(
@@ -726,6 +780,8 @@ def cli_api_server(
726780
" version in the dev environment)"
727781
),
728782
)
783+
@adk_services_options()
784+
@deprecated_adk_services_options()
729785
def cli_deploy_cloud_run(
730786
agent: str,
731787
project: Optional[str],
@@ -737,9 +793,12 @@ def cli_deploy_cloud_run(
737793
trace_to_cloud: bool,
738794
with_ui: bool,
739795
verbosity: str,
740-
session_db_url: str,
741-
artifact_storage_uri: Optional[str],
742796
adk_version: str,
797+
session_service_uri: Optional[str] = None,
798+
artifact_service_uri: Optional[str] = None,
799+
memory_service_uri: Optional[str] = None,
800+
session_db_url: Optional[str] = None, # Deprecated
801+
artifact_storage_uri: Optional[str] = None, # Deprecated
743802
):
744803
"""Deploys an agent to Cloud Run.
745804
@@ -749,6 +808,8 @@ def cli_deploy_cloud_run(
749808
750809
adk deploy cloud_run --project=[project] --region=[region] path/to/my_agent
751810
"""
811+
session_service_uri = session_service_uri or session_db_url
812+
artifact_service_uri = artifact_service_uri or artifact_storage_uri
752813
try:
753814
cli_deploy.to_cloud_run(
754815
agent_folder=agent,
@@ -761,9 +822,10 @@ def cli_deploy_cloud_run(
761822
trace_to_cloud=trace_to_cloud,
762823
with_ui=with_ui,
763824
verbosity=verbosity,
764-
session_db_url=session_db_url,
765-
artifact_storage_uri=artifact_storage_uri,
766825
adk_version=adk_version,
826+
session_service_uri=session_service_uri,
827+
artifact_service_uri=artifact_service_uri,
828+
memory_service_uri=memory_service_uri,
767829
)
768830
except Exception as e:
769831
click.secho(f"Deploy failed: {e}", fg="red", err=True)

0 commit comments

Comments
 (0)
0