8000 feat(auth)!: expose access_token and refresh_token at top level of au… · jsondai/adk-python@956fb91 · GitHub
[go: up one dir, main page]

Skip to content

Commit 956fb91

Browse files
seanzhougooglecopybara-github
authored andcommitted
feat(auth)!: expose access_token and refresh_token at top level of auth credentails
BREAKING CHANGE: `token` attribute of OAuth2Auth credentials used to be a dict containing both access_token and refresh_token, given that may cause confusions, now we replace it with access_token and refresh_token at top level of the auth credentials PiperOrigin-RevId: 750346172
1 parent 49d8c0f commit 956fb91

File tree

6 files changed

+21
-20
lines changed

6 files changed

+21
-20
lines changed

src/google/adk/auth/auth_credential.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ class OAuth2Auth(BaseModelWithConfig):
6666
redirect_uri: Optional[str] = None
6767
auth_response_uri: Optional[str] = None
6868
auth_code: Optional[str] = None
69-
token: Optional[Dict[str, Any]] = None
69+
access_token: Optional[str] = None
70+
refresh_token: Optional[str] = None
7071

7172

7273
class ServiceAccountCredential(BaseModelWithConfig):

src/google/adk/auth/auth_handler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def exchange_auth_token(
8282
or not auth_credential.oauth2
8383
or not auth_credential.oauth2.client_id
8484
or not auth_credential.oauth2.client_secret
85-
or auth_credential.oauth2.token
85+
or auth_credential.oauth2.access_token
86+
or auth_credential.oauth2.refresh_token
8687
):
8788
return self.auth_config.exchanged_auth_credential
8889

@@ -93,7 +94,7 @@ def exchange_auth_token(
9394
redirect_uri=auth_credential.oauth2.redirect_uri,
9495
state=auth_credential.oauth2.state,
9596
)
96-
token = client.fetch_token(
97+
tokens = client.fetch_token(
9798
token_endpoint,
9899
authorization_response=auth_credential.oauth2.auth_response_uri,
99100
code=auth_credential.oauth2.auth_code,
@@ -102,7 +103,10 @@ def exchange_auth_token(
102103

103104
updated_credential = AuthCredential(
104105
auth_type=AuthCredentialTypes.OAUTH2,
105-
oauth2=OAuth2Auth(token=dict(token)),
106+
oauth2=OAuth2Auth(
107+
access_token=tokens.get("access_token"),
108+
refresh_token=tokens.get("refresh_token"),
109+
),
106110
)
107111
return updated_credential
108112

src/google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def generate_auth_token(
6969
HTTP bearer token cannot be generated, return the original credential.
7070
"""
7171

72-
if "access_token" not in auth_credential.oauth2.token:
72+
if not auth_credential.oauth2.access_token:
7373
return auth_credential
7474

7575
# Return the access token as a bearer token.
@@ -78,7 +78,7 @@ def generate_auth_token(
7878
http=HttpAuth(
7979
scheme="bearer",
8080
credentials=HttpCredentials(
81-
token=auth_credential.oauth2.token["access_token"]
81+
token=auth_credential.oauth2.access_token
8282
),
8383
),
8484
)
@@ -111,7 +111,7 @@ def exchange_credential(
111111
return auth_credential
112112

113113
# If access token is exchanged, exchange a HTTPBearer token.
114-
if auth_credential.oauth2.token:
114+
if auth_credential.oauth2.access_token:
115115
return self.generate_auth_token(auth_credential)
116116

117117
return None

tests/unittests/auth/test_auth_handler.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,8 @@ def oauth2_credentials_with_token():
126126
client_id="mock_client_id",
127127
client_secret="mock_client_secret",
128128
redirect_uri="https://example.com/callback",
129-
token={
130-
"access_token": "mock_access_token",
131-
"token_type": "bearer",
132-
"expires_in": 3600,
133-
"refresh_token": "mock_refresh_token",
134-
},
129+
access_token="mock_access_token",
130+
refresh_token="mock_refresh_token",
135131
),
136132
)
137133

@@ -458,7 +454,7 @@ def test_oauth_scheme(self, mock_exchange_token, auth_config_with_exchanged):
458454
"""Test with an OAuth auth scheme."""
459455
mock_exchange_token.return_value = AuthCredential(
460456
auth_type=AuthCredentialTypes.OAUTH2,
461-
oauth2=OAuth2Auth(token={"access_token": "exchanged_token"}),
457+
oauth2=OAuth2Auth(access_token="exchanged_token"),
462458
)
463459

464460
handler = AuthHandler(auth_config_with_exchanged)
@@ -573,6 +569,6 @@ def test_successful_token_exchange(self, auth_config_with_auth_code):
573569
handler = AuthHandler(auth_config_with_auth_code)
574570
result = handler. 9E81 exchange_auth_token()
575571

576-
assert result.oauth2.token["access_token"] == "mock_access_token"
577-
assert result.oauth2.token["refresh_token"] == "mock_refresh_token"
572+
assert result.oauth2.access_token == "mock_access_token"
573+
assert result.oauth2.refresh_token == "mock_refresh_token"
578574
assert result.auth_type == AuthCredentialTypes.OAUTH2

tests/unittests/flows/llm_flows/test_functions_request_euc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_function_get_auth_response():
246246
oauth2=OAuth2Auth(
247247
client_id='oauth_client_id_1',
248248
client_secret='oauth_client_secret1',
249-
token={'access_token': 'token1'},
249+
access_token='token1',
250250
),
251251
),
252252
)
@@ -277,7 +277,7 @@ def test_function_get_auth_response():
277277
oauth2=OAuth2Auth(
278278
client_id='oauth_client_id_2',
279279
client_secret='oauth_client_secret2',
280-
token={'access_token': 'token2'},
280+
access_token='token2',
281281
),
282282
),
283283
)

tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_oauth2_exchanger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_generate_auth_token_success(
110110
client_secret="test_secret",
111111
redirect_uri="http://localhost:8080",
112112
auth_response_uri="https://example.com/callback?code=test_code",
113-
token={"access_token": "test_access_token"},
113+
access_token="test_access_token",
114114
),
115115
)
116116
updated_credential = oauth2_exchanger.generate_auth_token(auth_credential)
@@ -131,7 +131,7 @@ def test_exchange_credential_generate_auth_token(
131131
client_secret="test_secret",
132132
redirect_uri="http://localhost:8080",
133133
auth_response_uri="https://example.com/callback?code=test_code",
134-
token={"access_token": "test_access_token"},
134+
access_token="test_access_token",
135135
),
136136
)
137137

0 commit comments

Comments
 (0)
0