8000 feat: Add an option to use gcs artifact service in adk web. · nag763/adk-python@8d36dbd · GitHub
[go: up one dir, main page]

Skip to content

Commit 8d36dbd

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Add an option to use gcs artifact service in adk web.
Resolves google#309 PiperOrigin-RevId: 765772763
1 parent 0e72efb commit 8d36dbd

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

src/google/adk/cli/cli_deploy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
import os
1617
import shutil
@@ -86,6 +87,7 @@ def to_cloud_run(
8687
with_ui: bool,
8788
verbosity: str,
8889
session_db_url: str,
90+
artifact_storage_uri: Optional[str],
8991
adk_version: str,
9092
):
9193
"""Deploys an agent to Google Cloud Run.
@@ -115,6 +117,7 @@ def to_cloud_run(
115117
with_ui: Whether to deploy with UI.
116118
verbosity: The verbosity level of the CLI.
117119
session_db_url: The database URL to connect the session.
120+
artifact_storage_uri: The artifact storage URI to store the artifacts.
118121
adk_version: The ADK version to use in Cloud Run.
119122
"""
120123
app_name = app_name or os.path.basename(agent_folder)
@@ -152,6 +155,9 @@ def to_cloud_run(
152155
session_db_option=f'--session_db_url={session_db_url}'
153156
if session_db_url
154157
else '',
158+
artifact_storage_option=f'--artifact_storage_uri={artifact_storage_uri}'
159+
if artifact_storage_uri
160+
else '',
155161
trace_to_cloud_option='--trace_to_cloud' if trace_to_cloud else '',
156162
adk_version=adk_version,
157163
host_option=host_option,

src/google/adk/cli/cli_tools_click.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,15 @@ def decorator(func):
430430
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
431431
),
432432
)
433+
@click.option(
434+
"--artifact_storage_uri",
435+
type=str,
436+
help=(
437+
"Optional. The artifact storage URI to store the artifacts,"
438+
8000 " supported URIs: gs://<bucket name> for GCS artifact service."
439+
),
440+
default=None,
441+
)
433442
@click.option(
434443
"--host",
435444
type=str,
@@ -490,6 +499,7 @@ def wrapper(*args, **kwargs):
490499
def cli_web(
491500
agents_dir: str,
492501
session_db_url: str = "",
502+
artifact_storage_uri: Optional[str] = None,
493503
log_level: str = "INFO",
494504
allow_origins: Optional[list[str]] = None,
495505
host: str = "127.0.0.1",
@@ -533,6 +543,7 @@ async def _lifespan(app: FastAPI):
533543
app = get_fast_api_app(
534544
agents_dir=agents_dir,
535545
session_db_url=session_db_url,
546+
artifact_storage_uri=artifact_storage_uri,
536547
allow_origins=allow_origins,
537548
web=True,
538549
trace_to_cloud=trace_to_cloud,
@@ -563,6 +574,7 @@ async def _lifespan(app: FastAPI):
563574
def cli_api_server(
564575
agents_dir: str,
565576
session_db_url: str = "",
577+
artifact_storage_uri: Optional[str] = None,
566578
log_level: str = "INFO",
567579
allow_origins: Optional[list[str]] = None,
568580
host: str = "127.0.0.1",
@@ -585,6 +597,7 @@ def cli_api_server(
585597
get_fast_api_app(
586598
agents_dir=agents_dir,
587599
session_db_url=session_db_url,
600+
artifact_storage_uri=artifact_storage_uri,
588601
allow_origins=allow_origins,
589602
web=False,
590603
trace_to_cloud=trace_to_cloud,
@@ -688,6 +701,15 @@ def cli_api_server(
688701
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported DB URLs."""
689702
),
690703
)
704+
@click.option(
705+
"--artifact_storage_uri",
706+
type=str,
707+
help=(
708+
"Optional. The artifact storage URI to store the artifacts, supported"
709+
" URIs: gs://<bucket name> for GCS artifact service."
710+
),
711+
default=None,
712+
)
691713
@click.argument(
692714
"agent",
693715
type=click.Path(
@@ -716,6 +738,7 @@ def cli_deploy_cloud_run(
716738
with_ui: bool,
717739
verbosity: str,
718740
session_db_url: str,
741+
artifact_storage_uri: Optional[str],
719742
adk_version: str,
720743
):
721744
"""Deploys an agent to Cloud Run.
@@ -739,6 +762,7 @@ def cli_deploy_cloud_run(
739762
with_ui=with_ui,
740763
verbosity=verbosity,
741764
session_db_url=session_db_u 27AE rl,
765+
artifact_storage_uri=artifact_storage_uri,
742766
adk_version=adk_version,
743767
)
744768
except Exception as e:

src/google/adk/cli/fast_api.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
from __future__ import annotations
1716

1817
import asyncio
@@ -56,6 +55,7 @@
5655
from ..agents.live_request_queue import LiveRequestQueue
5756
from ..agents.llm_agent import Agent
5857
from ..agents.run_config import StreamingMode
58+
from ..artifacts.gcs_artifact_service import GcsArtifactService
5959
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
6060
from ..evaluation.eval_case import EvalCase
6161
from ..evaluation.eval_case import SessionInput
@@ -193,6 +193,7 @@ def get_fast_api_app(
193193
*,
194194
agents_dir: str,
195195
session_db_url: str = "",
196+
artifact_storage_uri: Optional[str] = None,
196197
allow_origins: Optional[list[str]] = None,
197198
web: bool,
198199
trace_to_cloud: bool = False,
@@ -251,13 +252,12 @@ async def internal_lifespan(app: FastAPI):
251252

252253
runner_dict = {}
253254

254-
# Build the Artifact service
255-
artifact_service = InMemoryArtifactService()
256-
memory_service = InMemoryMemoryService()
257-
258255
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
259256
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
260257

258+
# Build the Memory service
259+
memory_service = InMemoryMemoryService()
260+
261261
# Build the Session service
262262
agent_engine_id = ""
263263
if session_db_url:
@@ -276,6 +276,18 @@ async def internal_lifespan(app: FastAPI):
276276
else:
277277
session_service = InMemorySessionService()
278278

279+
# Build the Artifact service
280+
if artifact_storage_uri:
281+
if artifact_storage_uri.startswith("gs://"):
282+
gcs_bucket = artifact_storage_uri.split("://")[1]
283+
artifact_service = GcsArtifactService(bucket_name=gcs_bucket)
284+
else:
285+
raise click.ClickException(
286+
"Unsupported artifact storage URI: %s" % artifact_storage_uri
287+
)
288+
else:
289+
artifact_service = InMemoryArtifactService()
290+
279291
# initialize Agent Loader
280292
agent_loader = AgentLoader(agents_dir)
281293

tests/unittests/cli/utils/test_cli_deploy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def _recording_copytree(*args: Any, **kwargs: Any):
128128
with_ui=True,
129129
verbosity="info",
130130
session_db_url="sqlite://",
131+
artifact_storage_uri="gs://bucket",
131132
adk_version="0.0.5",
132133
)
133134

@@ -170,6 +171,7 @@ def _fake_rmtree(path: str | Path, *a: Any, **k: Any) -> None:
170171
with_ui=False,
171172
verbosity="info",
172173
session_db_url=None,
174+
artifact_storage_uri=None,
173175
adk_version="0.0.5",
174176
)
175177

0 commit comments

Comments
 (0)
0