From ae105b61841c3cb1be5a2ad30d2287e747cfc551 Mon Sep 17 00:00:00 2001 From: yangben Date: Tue, 6 May 2025 20:08:03 +0800 Subject: [PATCH 1/3] Fix the issue of get Authorization header fails during bearer auth --- src/mcp/server/auth/middleware/bearer_auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 295605af7..a767f3399 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -34,7 +34,7 @@ def __init__( self.provider = provider async def authenticate(self, conn: HTTPConnection): - auth_header = conn.headers.get("Authorization") + auth_header = conn.headers.get("Authorization") or conn.headers.get("authorization") if not auth_header or not auth_header.startswith("Bearer "): return None From 7b7a73310fcc7a5d9688714370377f4ac03346c5 Mon Sep 17 00:00:00 2001 From: yangben Date: Tue, 6 May 2025 20:19:48 +0800 Subject: [PATCH 2/3] Fix the issue of get Authorization header fails during bearer auth --- src/mcp/server/auth/middleware/bearer_auth.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index a767f3399..dde49848c 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -34,7 +34,9 @@ def __init__( self.provider = provider async def authenticate(self, conn: HTTPConnection): - auth_header = conn.headers.get("Authorization") or conn.headers.get("authorization") + auth_header = conn.headers.get("Authorization") or conn.headers.get( + "authorization" + ) if not auth_header or not auth_header.startswith("Bearer "): return None From ef64a4a97aafb0d3624fbc14546edabeba253201 Mon Sep 17 00:00:00 2001 From: yangben Date: Wed, 7 May 2025 21:44:45 +0800 Subject: [PATCH 3/3] Fix the case sensitivity issue in headers of bearer auth --- src/mcp/server/auth/middleware/bearer_auth.py | 11 +++- .../auth/middleware/test_bearer_auth.py | 61 +++++++++++++++++++ 2 files changed, 69 insertions(+), 3 deletions(-) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index dde49848c..30b5e2ba6 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -34,10 +34,15 @@ def __init__( self.provider = provider async def authenticate(self, conn: HTTPConnection): - auth_header = conn.headers.get("Authorization") or conn.headers.get( - "authorization" + auth_header = next( + ( + conn.headers.get(key) + for key in conn.headers + if key.lower() == "authorization" + ), + None, ) - if not auth_header or not auth_header.startswith("Bearer "): + if not auth_header or not auth_header.lower().startswith("bearer "): return None token = auth_header[7:] # Remove "Bearer " prefix diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py index 9acb5ff09..e8c17a4c4 100644 --- a/tests/server/auth/middleware/test_bearer_auth.py +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -7,6 +7,7 @@ import pytest from starlette.authentication import AuthCredentials +from starlette.datastructures import Headers from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.types import Message, Receive, Scope, Send @@ -221,6 +222,66 @@ async def test_token_without_expiry( assert user.access_token == no_expiry_access_token assert user.scopes == ["read", "write"] + async def test_lowercase_bearer_prefix( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test with lowercase 'bearer' prefix in Authorization header""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"Authorization": "bearer valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + + async def test_mixed_case_bearer_prefix( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test with mixed 'BeArEr' prefix in Authorization header""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"authorization": "BeArEr valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + + async def test_mixed_case_authorization_header( + self, + mock_oauth_provider: OAuthAuthorizationServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test authentication with mixed 'Authorization' header.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + headers = Headers({"AuThOrIzAtIoN": "BeArEr valid_token"}) + scope = {"type": "http", "headers": headers.raw} + request = Request(scope) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + @pytest.mark.anyio class TestRequireAuthMiddleware: