8000 refactor: Extract util method from OAuth2 credential fetcher for reuse · jsdzhang/adk-python@94caccc · GitHub
[go: up one dir, main page]

Skip to content

Commit 94caccc

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: Extract util method from OAuth2 credential fetcher for reuse
Context: we'd like to separate fetcher into exchanger and refresher later. This cl help to extract the common utility that will be used by both exchanger and refresher. PiperOrigin-RevId: 772257995
1 parent 476805d commit 94caccc

File tree

5 files changed

+284
-363
lines changed

5 files changed

+284
-363
lines changed

src/google/adk/auth/oauth2_credential_fetcher.py

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,15 @@
1515
from __future__ import annotations
1616

1717
import logging
18-
from typing import Optional
19-
from typing import Tuple
20-
21-
from fastapi.openapi.models import OAuth2
2218

2319
from ..utils.feature_decorator import experimental
2420
from .auth_credential import AuthCredential
2521
from .auth_schemes import AuthScheme
2622
from .auth_schemes import OAuthGrantType
27-
from .auth_schemes import OpenIdConnectWithConfig
23+
from .oauth2_credential_util import create_oauth2_session
24+
from .oauth2_credential_util import update_credential_with_tokens
2825

2926
try:
30-
from authlib.integrations.requests_client import OAuth2Session
3127
from authlib.oauth2.rfc6749 import OAuth2Token
3228

3329
AUTHLIB_AVIALABLE = True
@@ -50,45 +46,6 @@ def __init__(
5046
self._auth_scheme = auth_scheme
5147
self._auth_credential = auth_credential
5248

53-
def _oauth2_session(self) -> Tuple[Optional[OAuth2Session], Optional[str]]:
54-
auth_scheme = self._auth_scheme
55-
auth_credential = self._auth_credential
56-
57-
if isinstance(auth_scheme, OpenIdConnectWithConfig):
58-
if not hasattr(auth_scheme, "token_endpoint"):
59-
return None, None
60-
token_endpoint = auth_scheme.token_endpoint
61-
scopes = auth_scheme.scopes
62-
elif isinstance(auth_scheme, OAuth2):
63-
if (
64-
not auth_scheme.flows.authorizationCode
65-
or not auth_scheme.flows.authorizationCode.tokenUrl
66-
):
67-
return None, None
68-
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
69-
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
70-
else:
71-
return None, None
72-
73-
if (
74-
not auth_credential
75-
or not auth_credential.oauth2
76-
or not auth_credential.oauth2.client_id
77-
or not auth_credential.oauth2.client_secret
78-
):
79-
return None, None
80-
81-
return (
82-
OAuth2Session(
83-
auth_credential.oauth2.client_id,
84-
auth_credential.oauth2.client_secret,
85-
scope=" ".join(scopes),
86-
redirect_uri=auth_credential.oauth2.redirect_uri,
87-
state=auth_credential.oauth2.state,
88-
),
89-
token_endpoint,
90-
)
91-
9249
def _update_credential(self, tokens: OAuth2Token) -> None:
9350
self._auth_credential.oauth2.access_token = tokens.get("access_token")
9451
self._auth_credential.oauth2.refresh_token = tokens.get("refresh_token")
@@ -114,7 +71,9 @@ def exchange(self) -> AuthCredential:
11471
):
11572
return self._auth_credential
11673

117-
client, token_endpoint = self._oauth2_session()
74+
client, token_endpoint = create_oauth2_session(
75+
self._auth_scheme, self._auth_credential
76+
)
11877
if not client:
11978
logger.warning("Could not create OAuth2 session for token exchange")
12079
return self._auth_credential
@@ -126,7 +85,7 @@ def exchange(self) -> AuthCredential:
12685
code=self._auth_credential.oauth2.auth_code,
12786
grant_type=OAuthGrantType.AUTHORIZATION_CODE,
12887
)
129-
self._update_credential(tokens)
88+
update_credential_with_tokens(self._auth_credential, tokens)
13089
logger.info("Successfully exchanged OAuth2 tokens")
13190
except Exception as e:
13291
logger.error("Failed to exchange OAuth2 tokens: %s", e)
@@ -151,7 +110,9 @@ def refresh(self) -> AuthCredential:
151110
"expires_at": credential.oauth2.expires_at,
152111
"expires_in": credential.oauth2.expires_in,
153112
}).is_expired():
154-
client, token_endpoint = self._oauth2_session()
113+
client, token_endpoint = create_oauth2_session(
114+
self._auth_scheme, self._auth_credential
115+
)
155116
if not client:
156117
logger.warning("Could not create OAuth2 session for token refresh")
157118
return credential
@@ -161,7 +122,7 @@ def refresh(self) -> AuthCredential:
161122
url=token_endpoint,
162123
refresh_token=credential.oauth2.refresh_token,
163124
)
164-
self._update_credential(tokens)
125+
update_credential_with_tokens(self._auth_credential, tokens)
165126
logger.info("Successfully refreshed OAuth2 tokens")
166127
except Exception as e:
167128
logger.error("Failed to refresh OAuth2 tokens: %s", e)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import logging
18+
from typing import Optional
19+
from typing import Tuple
20+
21+
from fastapi.openapi.models import OAuth2
22+
23+
from ..utils.feature_decorator import experimental
24+
from .auth_credential import AuthCredential
25+
from .auth_schemes import AuthScheme
26+
from .auth_schemes import OpenIdConnectWithConfig
27+
28+
try:
29+
from authlib.integrations.requests_client import OAuth2Session
30+
from authlib.oauth2.rfc6749 import OAuth2Token
31+
32+
AUTHLIB_AVIALABLE = True
33+
except ImportError:
34+
AUTHLIB_AVIALABLE = False
35+
36+
37+
logger = logging.getLogger("google_adk." + __name__)
38+
39+
40+
@experimental
41+
def create_oauth2_session(
42+
auth_scheme: AuthScheme,
43+
auth_credential: AuthCredential,
44+
) -> Tuple[Optional[OAuth2Session], Optional[str]]:
45+
"""Create an OAuth2 session for token operations.
46+
47+
Args:
48+
auth_scheme: The authentication scheme configuration.
49+
auth_credential: The authentication credential.
50+
51+
Returns:
52+
Tuple of (OAuth2Session, token_endpoint) or (None, None) if cannot create session.
53+
"""
54+
if isinstance(auth_scheme, OpenIdConnectWithConfig):
55+
if not hasattr(auth_scheme, "token_endpoint"):
56+
return None, None
57+
token_endpoint = auth_scheme.token_endpoint
58+
scopes = auth_scheme.scopes
59+
elif isinstance(auth_scheme, OAuth2):
60+
if (
61 10000 +
not auth_scheme.flows.authorizationCode
62+
or not auth_scheme.flows.authorizationCode.tokenUrl
63+
):
64+
return None, None
65+
token_endpoint = auth_scheme.flows.authorizationCode.tokenUrl
66+
scopes = list(auth_scheme.flows.authorizationCode.scopes.keys())
67+
else:
68+
return None, None
69+
70+
if (
71+
not auth_credential
72+
or not auth_credential.oauth2
73+
or not auth_credential.oauth2.client_id
74+
or not auth_credential.oauth2.client_secret
75+
):
76+
return None, None
77+
78+
return (
79+
OAuth2Session(
80+
auth_credential.oauth2.client_id,
81+
auth_credential.oauth2.client_secret,
82+
scope=" ".join(scopes),
83+
redirect_uri=auth_credential.oauth2.redirect_uri,
84+
state=auth_credential.oauth2.state,
85+
),
86+
token_endpoint,
87+
)
88+
89+
90+
@experimental
91+
def update_credential_with_tokens(
92+
auth_credential: AuthCredential, tokens: OAuth2Token
93+
) -> None:
94+
"""Update the credential with new tokens.
95+
96+
Args:
97+
auth_credential: The authentication credential to update.
98+
tokens: The OAuth2Token object containing new token information.
99+
"""
100+
auth_credential.oauth2.access_token = tokens.get("access_token")
101+
auth_credential.oauth2.refresh_token = tokens.get("refresh_token")
102+
auth_credential.oauth2.expires_at = (
103+
int(tokens.get("expires_at")) if tokens.get("expires_at") else None
104+
)
105+
auth_credential.oauth2.expires_in = (
106+
int(tokens.get("expires_in")) if tokens.get("expires_in") else None
107+
)

tests/unittests/auth/test_auth_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def test_credentials_with_token(
538538
assert result == oauth2_credentials_with_token
539539

540540
@patch(
541-
"google.adk.auth.oauth2_credential_fetcher.OAuth2Session",
541+
"google.adk.auth.oauth2_credential_util.OAuth2Session",
542542
MockOAuth2Session,
543543
)
544544
def test_successful_token_exchange(self, auth_config_with_auth_code):

0 commit comments

Comments
 (0)
0