15
15
from __future__ import annotations
16
16
17
17
import logging
18
- from typing import Optional
19
- from typing import Tuple
20
-
21
- from fastapi .openapi .models import OAuth2
22
18
23
19
from ..utils .feature_decorator import experimental
24
20
from .auth_credential import AuthCredential
25
21
from .auth_schemes import AuthScheme
26
22
from .auth_schemes import OAuthGrantType
27
- from .auth_schemes import OpenIdConnectWithConfig
23
+ from .oauth2_credential_util import create_oauth2_session
24
+ from .oauth2_credential_util import update_credential_with_tokens
28
25
29
26
try :
30
- from authlib .integrations .requests_client import OAuth2Session
31
27
from authlib .oauth2 .rfc6749 import OAuth2Token
32
28
33
29
AUTHLIB_AVIALABLE = True
@@ -50,45 +46,6 @@ def __init__(
50
46
self ._auth_scheme = auth_scheme
51
47
self ._auth_credential = auth_credential
52
48
53
- def _oauth2_session (self ) -> Tuple [Optional [OAuth2Session ], Optional [str ]]:
54
- auth_scheme = self ._auth_scheme
55
- auth_credential = self ._auth_credential
56
-
57
- if isinstance (auth_scheme , OpenIdConnectWithConfig ):
58
- if not hasattr (auth_scheme , "token_endpoint" ):
59
- return None , None
60
- token_endpoint = auth_scheme .token_endpoint
61
- scopes = auth_scheme .scopes
62
- elif isinstance (auth_scheme , OAuth2 ):
63
- if (
64
- not auth_scheme .flows .authorizationCode
65
- or not auth_scheme .flows .authorizationCode .tokenUrl
66
- ):
67
- return None , None
68
- token_endpoint = auth_scheme .flows .authorizationCode .tokenUrl
69
- scopes = list (auth_scheme .flows .authorizationCode .scopes .keys ())
70
- else :
71
- return None , None
72
-
73
- if (
74
- not auth_credential
75
- or not auth_credential .oauth2
76
- or not auth_credential .oauth2 .client_id
77
- or not auth_credential .oauth2 .client_secret
78
- ):
79
- return None , None
80
-
81
- return (
82
- OAuth2Session (
83
- auth_credential .oauth2 .client_id ,
84
- auth_credential .oauth2 .client_secret ,
85
- scope = " " .join (scopes ),
86
- redirect_uri = auth_credential .oauth2 .redirect_uri ,
87
- state = auth_credential .oauth2 .state ,
88
- ),
89
- token_endpoint ,
90
- )
91
-
92
49
def _update_credential (self , tokens : OAuth2Token ) -> None :
93
50
self ._auth_credential .oauth2 .access_token = tokens .get ("access_token" )
94
51
self ._auth_credential .oauth2 .refresh_token = tokens .get ("refresh_token" )
@@ -114,7 +71,9 @@ def exchange(self) -> AuthCredential:
114
71
):
115
72
return self ._auth_credential
116
73
117
- client , token_endpoint = self ._oauth2_session ()
74
+ client , token_endpoint = create_oauth2_session (
75
+ self ._auth_scheme , self ._auth_credential
76
+ )
118
77
if not client :
119
78
logger .warning ("Could not create OAuth2 session for token exchange" )
120
79
return self ._auth_credential
@@ -126,7 +85,7 @@ def exchange(self) -> AuthCredential:
126
85
code = self ._auth_credential .oauth2 .auth_code ,
127
86
grant_type = OAuthGrantType .AUTHORIZATION_CODE ,
128
87
)
129
- self ._update_credential ( tokens )
88
+ update_credential_with_tokens ( self ._auth_credential , tokens )
130
89
logger .info ("Successfully exchanged OAuth2 tokens" )
131
90
except Exception as e :
132
91
logger .error ("Failed to exchange OAuth2 tokens: %s" , e )
@@ -151,7 +110,9 @@ def refresh(self) -> AuthCredential:
151
110
"expires_at" : credential .oauth2 .expires_at ,
152
111
"expires_in" : credential .oauth2 .expires_in ,
153
112
}).is_expired ():
154
- client , token_endpoint = self ._oauth2_session ()
113
+ client , token_endpoint = create_oauth2_session (
114
+ self ._auth_scheme , self ._auth_credential
115
+ )
155
116
if not client :
156
117
logger .warning ("Could not create OAuth2 session for token refresh" )
157
118
return credential
@@ -161,7 +122,7 @@ def refresh(self) -> AuthCredential:
161
122
url = token_endpoint ,
162
123
refresh_token = credential .oauth2 .refresh_token ,
163
124
)
164
- self ._update_credential ( tokens )
125
+ update_credential_with_tokens ( self ._auth_credential , tokens )
165
126
logger .info ("Successfully refreshed OAuth2 tokens" )
166
127
except Exception as e :
167
128
logger .error ("Failed to refresh OAuth2 tokens: %s" , e )
0 commit comments