8000 chore: Add base credential service interface (WIP) · devevignesh/adk-python@8ebf229 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ebf229

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Add base credential service interface (WIP)
PiperOrigin-RevId: 771358480
1 parent b51a1f4 commit 8ebf229

File tree

7 files changed

+176
-92
lines changed

7 files changed

+176
-92
lines changed

contributing/samples/oauth_calendar_agent/agent.py

Lines changed: 41 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from google.adk.auth import AuthCredentialTypes
2828
from google.adk.auth import OAuth2Auth
2929
from google.adk.tools import ToolContext
30+
from google.adk.tools.authenticated_tool.base_authenticated_tool import AuthenticatedFunctionTool
31+
from google.adk.tools.authenticated_tool.credentials_store import ToolContextCredentialsStore
3032
from google.adk.tools.google_api_tool import CalendarToolset
3133
from google.auth.transport.requests import Request
3234
from google.oauth2.credentials import Credentials
@@ -56,6 +58,7 @@ def list_calendar_events(
5658
end_time: str,
5759
limit: int,
5860
tool_context: ToolContext,
61+
credential: AuthCredential,
5962
) -> list[dict]:
6063
"""Search for calendar events.
6164
@@ -80,84 +83,11 @@ def list_calendar_events(
8083
Returns:
8184
list[dict]: A list of events that match the search criteria.
8285
"""
83-
creds = None
84-
85-
# Check if the tokes were already in the session state, which means the user
86-
# has already gone through the OAuth flow and successfully authenticated and
87-
# authorized the tool to access their calendar.
88-
if "calendar_tool_tokens" in tool_context.state:
89-
creds = Credentials.from_authorized_user_info(
90-
tool_context.state["calendar_tool_tokens"], SCOPES
91-
)
92-
if not creds or not creds.valid:
93-
# If the access token is expired, refresh it with the refresh token.
94-
if creds and creds.expired and creds.refresh_token:
95-
creds.refresh(Request())
96-
else:
97-
auth_scheme = OAuth2(
98-
flows=OAuthFlows(
99-
authorizationCode=OAuthFlowAuthorizationCode(
100-
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
101-
tokenUrl="https://oauth2.googleapis.com/token",
102-
scopes={
103-
"https://www.googleapis.com/auth/calendar": (
104-
"See, edit, share, and permanently delete all the"
105-
" calendars you can access using Google Calendar"
106-
)
107-
},
108-
)
109-
)
110-
)
111-
auth_credential = AuthCredential(
112-
auth_type=AuthCredentialTypes.OAUTH2,
113-
oauth2=OAuth2Auth(
114-
client_id=oauth_client_id, client_secret=oauth_client_secret
115-
),
116-
)
117-
# If the user has not gone through the OAuth flow before, or the refresh
118-
# token also expired, we need to ask users to go through the OAuth flow.
119-
# First we check whether the user has just gone through the OAuth flow and
120-
# Oauth response is just passed back.
121-
auth_response = tool_context.get_auth_response(
122-
AuthConfig(
123-
auth_scheme=auth_scheme, raw_auth_credential=auth_credential
124-
)
125-
)
126-
if auth_response:
127-
# ADK exchanged the access token already for us
128-
access_token = auth_response.oauth2.access_token
129-
refresh_token = auth_response.oauth2.refresh_token
130-
131-
creds = Credentials(
132-
token=access_token,
133-
refresh_token=refresh_token,
134-
token_uri=auth_scheme.flows.authorizationCode.tokenUrl,
135-
client_id=oauth_client_id,
136-
client_secret=oauth_client_secret,
137-
scopes=list(auth_scheme.flows.authorizationCode.scopes.keys()),
138-
)
139-
else:
140-
# If there are no auth response which means the user has not gone
141-
# through the OAuth flow yet, we need to ask users to go through the
142-
# OAuth flow.
143-
tool_context.request_credential(
144-
AuthConfig(
145-
auth_scheme=auth_scheme,
146-
raw_auth_credential=auth_credential,
147-
)
148-
)
149-
# The return value is optional and could be any dict object. It will be
150-
# wrapped in a dict with key as 'result' and value as the return value
151-
# if the object returned is not a dict. This response will be passed
152-
# to LLM to generate a user friendly message. e.g. LLM will tell user:
153-
# "I need your authorization to access your calendar. Please authorize
154-
# me so I can check your meetings for today."
155-
return "Need User Authorization to access their calendar."
156-
# We store the access token and refresh token in the session state for the
157-
# next runs. This is just an example. On production, a tool should store
158-
# those credentials in some secure store or properly encrypt it before store
159-
# it in the session state.
160-
tool_context.state["calendar_tool_tokens"] = json.loads(creds.to_json())
86+
87+
creds = Credentials(
88+
token=credential.oauth2.access_token,
89+
refresh_token=credential.oauth2.refresh_token,
90+
)
16191

16292
service = build("calendar", "v3", credentials=creds)
16393
events_result = (
@@ -208,6 +138,38 @@ def update_time(callback_context: CallbackContext):
208138
209139
Currnet time: {_time}
210140
""",
211-
tools=[list_calendar_events, calendar_toolset],
141+
tools=[
142+
AuthenticatedFunctionTool(
143+
func=list_calendar_events,
144+
auth_config=AuthConfig(
145+
auth_scheme=OAuth2(
146+
flows=OAuthFlows(
147+
authorizationCode=OAuthFlowAuthorizationCode(
148+
authorizationUrl=(
149+
"https://accounts.google.com/o/oauth2/auth"
150+
),
151+
tokenUrl="https://oauth2.googleapis.com/token",
152+
scopes={
153+
"https://www.googleapis.com/auth/calendar": (
154+
"See, edit, share, and permanently delete"
155+
" all the calendars you can access using"
156+
" Google Calendar"
157+
)
158+
},
159+
)
160+
)
161+
),
162+
raw_auth_credential=AuthCredential(
163+
auth_type=AuthCredentialTypes.OAUTH2,
164+
oauth2=OAuth2Auth(
165+
client_id=oauth_client_id,
166+
client_secret=oauth_client_secret,
167+
),
168+
),
169+
),
170+
credential_store=ToolContextCredentialsStore(),
171+
),
172+
calendar_toolset,
173+
],
212174
before_agent_callback=update_time,
213175
)

src/google/adk/auth/auth_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def exchange_auth_token(
4949

5050
def parse_and_store_auth_response(self, state: State) -> None:
5151

52-
credential_key = "temp:" + self.auth_config.get_credential_key()
52+
credential_key = "temp:" + self.auth_config.credential_key
5353

5454
state[credential_key] = self.auth_config.exchanged_auth_credential
5555
if not isinstance(
@@ -67,7 +67,7 @@ def _validate(self) -> None:
6767
raise ValueError("auth_scheme is empty.")
6868

6969
def get_auth_response(self, state: State) -> AuthCredential:
70-
credential_key = "temp:" + self.auth_config.get_credential_key()
70+
credential_key = "temp:" + self.auth_config.credential_key
7171
return state.get(credential_key, None)
7272

7373
def generate_auth_request(self) -> AuthConfig:

src/google/adk/auth/auth_tool.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Optional
18+
19+
from typing_extensions import deprecated
20+
1721
from .auth_credential import AuthCredential
1822
from .auth_credential import BaseModelWithConfig
1923
from .auth_schemes import AuthScheme
@@ -45,11 +49,23 @@ class AuthConfig(BaseModelWithConfig):
4549
this field to guide the user through the OAuth2 flow and fill auth response in
4650
this field"""
4751

52+
credential_key: Optional[str] = None
53+
"""A user specified key used to load and save this credential in a credential
54+
service.
55+
"""
56+
57+
def __init__(self, **data):
58+
super().__init__(**data)
59+
if self.credential_key:
60+
return
61+
self.credential_key = self.get_credential_key()
62+
63+
@deprecated("This method is deprecated. Use credential_key instead.")
4864
def get_credential_key(self):
49-
"""Generates a hash key based on auth_scheme and raw_auth_credential. This
50-
hash key can be used to store / retrieve exchanged_auth_credential in a
51-
credentials store.
65+
"""Builds a hash key based on auth_scheme and raw_auth_credential used to
66+
save / load this credential to / from a credentials service.
5267
"""
68+
5369
auth_scheme = self.auth_scheme
5470

5571
if auth_scheme.model_extra:
@@ -62,7 +78,7 @@ def get_credential_key(self):
6278
)
6379

6480
auth_credential = self.raw_auth_credential
65-
if auth_credential.model_extra:
81+
if auth_credential and auth_credential.model_extra:
6682
auth_credential = auth_credential.model_copy(deep=True)
6783
auth_credential.model_extra.clear()
6884
credential_name = (
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from abc import ABC
18+
from abc import abstractmethod
19+
from typing import Optional
20+
21+
from ...tools.tool_context import ToolContext
22+
from ...utils.feature_decorator import working_in_progress
23+
from ..auth_credential import AuthCredential
24+
from ..auth_tool import AuthConfig
25+
26+
27+
@working_in_progress("Implementation are in progress. Don't use it for now.")
28+
class BaseCredentialService(ABC):
29+
"""Abstract class for Service that loads / saves tool credentials from / to
30+
the backend credential store."""
31+
32+
@abstractmethod
33+
async def load_credential(
34+
self,
35+
auth_config: AuthConfig,
36+
tool_context: ToolContext,
37+
) -> Optional[AuthCredential]:
38+
"""
39+
Loads the credential by auth config and current tool context from the
40+
backend credential store.
41+
42+
Args:
43+
auth_config: The auth config which contains the auth scheme and auth
44+
credential information. auth_config.get_credential_key will be used to
45+
build the key to load the credential.
46+
47+
tool_context: The context of the current invocation when the tool is
48+
trying to load the credential.
49+
50+
Returns:
51+
Optional[AuthCredential]: the credential saved in the store.
52+
53+
"""
54+
55+
@abstractmethod
56+
async def save_credential(
57+
self,
58+
auth_config: AuthConfig,
59+
tool_context: ToolContext,
60+
) -> None:
61+
"""
62+
Saves the exchanged_auth_credential in auth config to the backend credential
63+
store.
64+
65+
Args:
66+
auth_config: The auth config which contains the auth scheme and auth
67+
credential information. auth_config.get_credential_key will be used to
68+
build the key to save the credential.
69+
70+
tool_context: The context of the current invocation when the tool is
71+
trying to save the credential.
72+
73+
Returns:
74+
None
75+
"""

tests/unittests/auth/test_auth_config.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,28 @@ def auth_config(oauth2_auth_scheme, oauth2_credentials):
6868
)
6969

7070

71-
def test_get_credential_key(auth_config):
71+
@pytest.fixture
72+
def auth_config_with_key(oauth2_auth_scheme, oauth2_credentials):
73+
"""Create an AuthConfig for testing."""
74+
75+
return AuthConfig(
76+
auth_scheme=oauth2_auth_scheme,
77+
raw_auth_credential=oauth2_credentials,
78+
credential_key="test_key",
79+
)
80+
81+
82+
def test_custom_credential_key(auth_config_with_key):
83+
"""Test using custom credential key."""
84+
85+
key = auth_config_with_key.credential_key
86+
assert key == "test_key"
87+
88+
89+
def test_credential_key(auth_config):
7290
"""Test generating a unique credential key."""
7391

74-
key = auth_config.get_credential_key()
92+
key = auth_config.credential_key
7593
assert key.startswith("adk_oauth2_")
7694
assert "_oauth2_" in key
7795

@@ -80,8 +98,8 @@ def test_get_credential_key_with_extras(auth_config):
8098
"""Test generating a key when model_extra exists."""
8199
# Add model_extra to test cleanup
82100

83-
original_key = auth_config.get_credential_key()
84-
key = auth_config.get_credential_key()
101+
original_key = auth_config.credential_key
102+
key = auth_config.credential_key
85103

86104
auth_config.auth_scheme.model_extra["extra_field"] = "value"
87105
auth_config.raw_auth_credential.model_extra["extra_field"] = "value"

tests/unittests/auth/test_auth_handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def test_get_auth_response_exists(
387387
state = MockState()
388388

389389
# Store a credential in the state
390-
credential_key = auth_config.get_credential_key()
390+
credential_key = auth_config.credential_key
391391
state["temp:" + credential_key] = oauth2_credentials_with_auth_uri
392392

393393
result = handler.get_auth_response(state)
@@ -418,7 +418,7 @@ def test_non_oauth_scheme(self, auth_config_with_exchanged):
418418

419419
handler.parse_and_store_auth_response(state)
420420

421-
credential_key = auth_config.get_credential_key()
421+
credential_key = auth_config.credential_key
422422
assert (
423423
state["temp:" + credential_key] == auth_config.exchanged_auth_credential
424424
)
@@ -436,7 +436,7 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
436436

437437
handler.parse_and_store_auth_response(state)
438438

439-
credential_key = auth_config_with_exchanged.get_credential_key()
439+
credential_key = auth_config_with_exchanged.credential_key
440440
assert state["temp:" + credential_key] == mock_exchange_token.return_value
441441
assert mock_exchange_token.called
442442

0 commit comments

Comments
 (0)
0