8000 Merge branch 'main' into docs/update-streamable-http-docstring · bl3ck/adk-python@e8cfaa9 · GitHub
[go: up one dir, main page]

Skip to content

Commit e8cfaa9

Browse files
authored
Merge branch 'main' into docs/update-streamable-http-docstring
2 parents 5129974 + a17ebe6 commit e8cfaa9

24 files changed

+1479
-114
lines changed

contributing/samples/quickstart/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_weather(city: str) -> dict:
2929
"status": "success",
3030
"report": (
3131
"The weather in New York is sunny with a temperature of 25 degrees"
32-
" Celsius (41 degrees Fahrenheit)."
32+
" Celsius (77 degrees Fahrenheit)."
3333
),
3434
}
3535
else:

src/google/adk/agents/invocation_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pydantic import ConfigDict
2323

2424
from ..artifacts.base_artifact_service import BaseArtifactService
25+
from ..auth.credential_service.base_credential_service import BaseCredentialService
2526
from ..memory.base_memory_service import BaseMemoryService
2627
from ..sessions.base_session_service import BaseSessionService
2728
from ..sessions.session import Session
@@ -115,6 +116,7 @@ class InvocationContext(BaseModel):
115116
artifact_service: Optional[BaseArtifactService] = None
116117
session_service: BaseSessionService
117118
memory_service: Optional[BaseMemoryService] = None
119+
credential_service: Optional[BaseCredentialService] = None
118120

119121
invocation_id: str
120122
"""The id of this invocation context. Readonly."""

src/google/adk/auth/credential_service/base_credential_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from typing import Optional
2020

2121
from ...tools.tool_context import ToolContext
22-
from ...utils.feature_decorator import working_in_progress
22+
from ...utils.feature_decorator import experimental
2323
from ..auth_credential import AuthCredential
2424
from ..auth_tool import AuthConfig
2525

2626

27-
@working_in_progress("Implementation are in progress. Don't use it for now.")
27+
@experimental
2828
class BaseCredentialService(ABC):
2929
"""Abstract class for Service that loads / saves tool credentials from / to
3030
the backend credential store."""
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Optional
18+
19+
from typing_extensions import override
20+
21+
from ...tools.tool_context import ToolContext
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredential
24+
from ..auth_tool import AuthConfig
25+
from .base_credential_service import BaseCredentialService
26+
27+
28+
@experimental
29+
class InMemoryCredentialService(BaseCredentialService):
30+
"""Class for in memory implementation of credential service(Experimental)"""
31+
32+
def __init__(self):
33+
super().__init__()
34+
self._credentials = {}
35+
36+
@override
37+
async def load_credential(
38+
self,
39+
auth_config: AuthConfig,
40+
tool_context: ToolContext,
41+
) -> Optional[AuthCredential]:
42+
credential_bucket = self._get_bucket_for_current_context(tool_context)
43+
return credential_bucket.get(auth_config.credential_key)
44+
45+
@override
46+
async def save_credential(
47+
self,
48+
auth_config: AuthConfig,
49+
tool_context: ToolContext,
50+
) -> None:
51+
credential_bucket = self._get_bucket_for_current_context(tool_context)
52+
credential_bucket[auth_config.credential_key] = (
53+
auth_config.exchanged_auth_credential
54+
)
55+
56+
def _get_bucket_for_current_context(self, tool_context: ToolContext) -> str:
57+
app_name = tool_context._invocation_context.app_name
58+
user_id = tool_context._invocation_context.user_id
59+
60+
if app_name not in self._credentials:
61+
self._credentials[app_name] = {}
62+
if user_id not in self._credentials[app_name]:
63+
self._credentials[app_name][user_id] = {}
64+
return self._credentials[app_name][user_id]

src/google/adk/auth/exchanger/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
"""Credential exchanger module."""
1616

1717
from .base_credential_exchanger import BaseCredentialExchanger
18-
from .credential_exchanger_registry import CredentialExchangerRegistry
1918
from .service_account_credential_exchanger import ServiceAccountCredentialExchanger
2019

2120
__all__ = [
2221
"BaseCredentialExchanger",
23-
"CredentialExchangerRegistry",
2422
"ServiceAccountCredentialExchanger",
2523
]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credential exchanger registry."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Dict
20+
from typing import Optional
21+
22+
from ...utils.feature_decorator import experimental
23+
from ..auth_credential import AuthCredentialTypes
24+
from .base_credential_exchanger import BaseCredentialExchanger
25+
26+
27+
@experimental
28+
class CredentialExchangerRegistry:
29+
"""Registry for credential exchanger instances."""
30+
31+
def __init__(self):
32+
self._exchangers: Dict[AuthCredentialTypes, BaseCredentialExchanger] = {}
33+
34+
def register(
35+
self,
36+
credential_type: AuthCredentialTypes,
37+
exchanger_instance: BaseCredentialExchanger,
38+
) -> None:
39+
"""Register an exchanger instance for a credential type.
40+
41+
Args:
42+
credential_type: The credential type to register for.
43+
exchanger_instance: The exchanger instance to register.
44+
"""
45+
self._exchangers[credential_type] = exchanger_instance
46+
47+
def get_exchanger(
48+
self, credential_type: AuthCredentialTypes
49+
) -> Optional[BaseCredentialExchanger]:
50+
"""Get the exchanger instance for a credential type.
51+
52+
Args:
53+
credential_type: The credential type to get exchanger for.
54+
55+
Returns:
56+
The exchanger instance if registered, None otherwise.
57+
"""
58+
return self._exchangers.get(credential_type)

src/google/adk/auth/service_account_credential_exchanger.py renamed to src/google/adk/auth/exchanger/service_account_credential_exchanger.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,64 +16,79 @@
1616

1717
from __future__ import annotations
1818

19+
from typing import Optional
20+
1921
import google.auth
2022
from google.auth.transport.requests import Request
2123
from google.oauth2 import service_account
24+
from typing_extensions import override
2225

23-
from ..utils.feature_decorator import experimental
24-
from .auth_credential import AuthCredential
25-
from .auth_credential import AuthCredentialTypes
26-
from .auth_credential import HttpAuth
27-
from .auth_credential import HttpCredentials
26+
from ...utils.feature_decorator import experimental
27+
from ..auth_credential import AuthCredential
28+
from ..auth_credential import AuthCredentialTypes
29+
from ..auth_schemes import AuthScheme
30+
from .base_credential_exchanger import BaseCredentialExchanger
2831

2932

3033
@experimental
31-
class ServiceAccountCredentialExchanger:
34+
class ServiceAccountCredentialExchanger(BaseCredentialExchanger):
3235
"""Exchanges Google Service Account credentials for an access token.
3336
3437
Uses the default service credential if `use_default_credential = True`.
3538
Otherwise, uses the service account credential provided in the auth
3639
credential.
3740
"""
3841

39-
def __init__(self, credential: AuthCredential):
40-
if credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT:
41-
raise ValueError("Credential is not a service account credential.")
42-
self._credential = credential
43-
44-
def exchange(self) -> AuthCredential:
42+
@override
43+
async def exchange(
44+
self,
45+
auth_credential: AuthCredential,
46+
auth_scheme: Optional[AuthScheme] = None,
47+
) -> AuthCredential:
4548
"""Exchanges the service account auth credential for an access token.
4649
4750
If the AuthCredential contains a service account credential, it will be used
4851
to exchange for an access token. Otherwise, if use_default_credential is True,
4952
the default application credential will be used for exchanging an access token.
5053
54+
Args:
55+
auth_scheme: The authentication scheme.
56+
auth_credential: The credential to exchange.
57+
5158
Returns:
52-
An AuthCredential in HTTP Bearer format, containing the access token.
59+
An AuthCredential in OAUTH2 format, containing the exchanged credential JSON.
5360
5461
Raises:
5562
ValueError: If service account credentials are missing or invalid.
5663
Exception: If credential exchange or refresh fails.
5764
"""
65+
if auth_credential is None:
66+
raise ValueError("Credential cannot be None.")
67+
68+
if auth_credential.auth_type != AuthCredentialTypes.SERVICE_ACCOUNT:
69+
raise ValueError("Credential is not a service account credential.")
70+
71+
if auth_credential.service_account is None:
72+
raise ValueError(
73+
"Service account credentials are missing. Please provide them."
74+
)
75+
5876
if (
59-
self._credential is None
60-
or self._credential.service_account is None
61-
or (
62-
self._credential.service_account.service_account_credential is None
63-
and not self._credential.service_account.use_default_credential
64-
)
77+
auth_credential.service_account.service_account_credential is None
78+
and not auth_credential.service_account.use_default_credential
6579
):
6680
raise ValueError(
67-
"Service account credentials are missing. Please provide them, or set"
68-
" `use_default_credential = True` to use application default"
69-
" credential in a hosted service like Google Cloud Run."
81+
"Service account credentials are invalid. Please set the"
82+
" service_account_credential field or set `use_default_credential ="
83+
" True` to use application default credential in a hosted service"
84+
" like Google Cloud Run."
7085
)
7186

7287
try:
73-
if self._credential.service_account.use_default_credential:
88+
< F438 span class="pl-k">if auth_credential.service_account.use_default_credential:
7489
credentials, _ = google.auth.default()
7590
else:
76-
config = self._credential.service_account
91+
config = auth_credential.service_account
7792
credentials = service_account.Credentials.from_service_account_info(
7893
config.service_account_credential.model_dump(), scopes=config.scopes
7994
)
@@ -82,11 +97,8 @@ def exchange(self) -> AuthCredential:
8297
credentials.refresh(Request())
8398

8499
return AuthCredential(
85-
auth_type=AuthCredentialTypes.HTTP,
86-
http=HttpAuth(
87-
scheme="bearer",
88-
credentials=HttpCredentials(token=credentials.token),
89-
),
100+
auth_type=AuthCredentialTypes.OAUTH2,
101+
google_oauth2_json=credentials.to_json(),
90102
)
91103
except Exception as e:
92104
raise ValueError(f"Failed to exchange service account token: {e}") from e
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credential refresher registry."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Dict
20+
from typing import Optional
21+
22+
from google.adk.auth.auth_credential import AuthCredentialTypes
23+
from google.adk.utils.feature_decorator import experimental
24+
25+
from .base_credential_refresher import BaseCredentialRefresher
26+
27+
28+
@experimental
29+
class CredentialRefresherRegistry:
30+
"""Registry for credential refresher instances."""
31+
32+
def __init__(self):
33+
self._refreshers: Dict[AuthCredentialTypes, BaseCredentialRefresher] = {}
34+
35+
def register(
36+
self,
37+
credential_type: AuthCredentialTypes,
38+
refresher_instance: BaseCredentialRefresher,
39+
) -> None:
40+
"""Register a refresher instance for a credential type.
41+
42+
Args:
43+
credential_type: The credential type to register for.
44+
refresher_instance: The refresher instance to register.
45+
"""
46+
self._refreshers[credential_type] = refresher_instance
47+
48+
def get_refresher(
49+
self, credential_type: AuthCredentialTypes
50+
) -> Optional[BaseCredentialRefresher]:
51+
"""Get the refresher instance for a credential type.
52+
53+
Args:
54+
credential_type: The credential type to get refresher for.
55+
56+
Returns:
57+
The refresher instance if registered, None otherwise.
58+
"""
59+
return self._refreshers.get(credential_type)

0 commit comments

Comments
 (0)
0