From c7183f6d55541f833b52eb02d22b920da78261a0 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:33:35 -0700 Subject: [PATCH 01/28] Exploratory implementation --- .gitignore | 5 +- .../AuthorizationExample.csproj | 14 + samples/AuthorizationExample/Program.cs | 81 ++++ .../McpAuthorizationException.cs | 77 ++++ src/ModelContextProtocol/McpErrorCode.cs | 16 + .../Protocol/Auth/AuthorizationContext.cs | 91 ++++ .../Auth/AuthorizationServerMetadata.cs | 69 +++ .../Protocol/Auth/AuthorizationService.cs | 423 ++++++++++++++++++ .../Protocol/Auth/ClientMetadata.cs | 99 ++++ .../Auth/ClientRegistrationResponse.cs | 33 ++ .../Auth/DefaultAuthorizationHandler.cs | 265 +++++++++++ .../Protocol/Auth/IAuthorizationHandler.cs | 24 + .../Protocol/Auth/ResourceMetadata.cs | 39 ++ .../Protocol/Auth/TokenResponse.cs | 39 ++ .../Transport/SseClientSessionTransport.cs | 177 ++++++-- .../Protocol/Transport/SseClientTransport.cs | 117 +++++ .../Transport/SseClientTransportOptions.cs | 29 ++ .../Utils/SynchronizedValue.cs | 75 ++++ 18 files changed, 1634 insertions(+), 39 deletions(-) create mode 100644 samples/AuthorizationExample/AuthorizationExample.csproj create mode 100644 samples/AuthorizationExample/Program.cs create mode 100644 src/ModelContextProtocol/McpAuthorizationException.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/ClientMetadata.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs create mode 100644 src/ModelContextProtocol/Utils/SynchronizedValue.cs diff --git a/.gitignore b/.gitignore index 171615f9..a9ca39b1 100644 --- a/.gitignore +++ b/.gitignore @@ -80,4 +80,7 @@ docs/api # Rider .idea/ -.idea_modules/ \ No newline at end of file +.idea_modules/ + +# Specs +.specs/ \ No newline at end of file diff --git a/samples/AuthorizationExample/AuthorizationExample.csproj b/samples/AuthorizationExample/AuthorizationExample.csproj new file mode 100644 index 00000000..60091057 --- /dev/null +++ b/samples/AuthorizationExample/AuthorizationExample.csproj @@ -0,0 +1,14 @@ + + + + Exe + net8.0 + enable + enable + + + + + + + \ No newline at end of file diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs new file mode 100644 index 00000000..2594863a --- /dev/null +++ b/samples/AuthorizationExample/Program.cs @@ -0,0 +1,81 @@ +using System.Diagnostics; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; + +namespace AuthorizationExample; + +/// +/// Example demonstrating how to use the MCP C# SDK with OAuth authorization. +/// +public class Program +{ + public static async Task Main(string[] args) + { + // Define the MCP server endpoint that requires OAuth authentication + var serverEndpoint = new Uri("https://example.com/mcp"); + + // Set up the SSE transport with authorization support + var transportOptions = new SseClientTransportOptions + { + Endpoint = serverEndpoint, + + // Provide a callback to handle the authorization flow + AuthorizeCallback = async (clientMetadata) => + { + Console.WriteLine("Authentication required. Opening browser for authorization..."); + + // In a real app, you'd likely have a local HTTP server to receive the callback + // This is just a simplified example + Console.WriteLine("Once you've authorized in the browser, enter the code and redirect URI:"); + Console.Write("Code: "); + var code = Console.ReadLine() ?? ""; + Console.Write("Redirect URI: "); + var redirectUri = Console.ReadLine() ?? "http://localhost:8888/callback"; + + return (redirectUri, code); + } + + // Alternatively, use the built-in local server handler: + // AuthorizeCallback = SseClientTransport.CreateLocalServerAuthorizeCallback( + // openBrowser: async (url) => + // { + // // Open the URL in the user's default browser + // Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); + // } + // ) + }; + + try + { + // Create the client with authorization-enabled transport + var transport = new SseClientTransport(transportOptions); + var client = await McpClient.CreateAsync(transport); + + // Use the MCP client normally - authorization is handled automatically + // If the server returns a 401 Unauthorized response, the authorization flow will be triggered + var result = await client.PingAsync(); + Console.WriteLine($"Server ping successful: {result.ServerInfo.Name} {result.ServerInfo.Version}"); + + // Example tool call + var weatherPrompt = "What's the weather like today?"; + var weatherResult = await client.CompletionCompleteAsync( + new CompletionCompleteRequestBuilder(weatherPrompt).Build()); + + Console.WriteLine($"Response: {weatherResult.Content.Text}"); + } + catch (McpAuthorizationException authEx) + { + Console.WriteLine($"Authorization error: {authEx.Message}"); + Console.WriteLine($"Resource: {authEx.ResourceUri}"); + Console.WriteLine($"Auth server: {authEx.AuthorizationServerUri}"); + } + catch (McpException mcpEx) + { + Console.WriteLine($"MCP error: {mcpEx.Message} (Error code: {mcpEx.ErrorCode})"); + } + catch (Exception ex) + { + Console.WriteLine($"Unexpected error: {ex.Message}"); + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/McpAuthorizationException.cs b/src/ModelContextProtocol/McpAuthorizationException.cs new file mode 100644 index 00000000..6dca9a2f --- /dev/null +++ b/src/ModelContextProtocol/McpAuthorizationException.cs @@ -0,0 +1,77 @@ +namespace ModelContextProtocol; + +/// +/// Represents an exception that is thrown when an authorization or authentication error occurs in MCP. +/// +/// +/// This exception is thrown when the client fails to authenticate with an MCP server that requires +/// authentication, such as when the OAuth authorization flow fails or when the server rejects the provided credentials. +/// +public class McpAuthorizationException : McpException +{ + /// + /// Initializes a new instance of the class. + /// + public McpAuthorizationException() + : base("Authorization failed", McpErrorCode.Unauthorized) + { + } + + /// + /// Initializes a new instance of the class with a specified error message. + /// + /// The message that describes the error. + public McpAuthorizationException(string message) + : base(message, McpErrorCode.Unauthorized) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + public McpAuthorizationException(string message, Exception? innerException) + : base(message, innerException, McpErrorCode.Unauthorized) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and error code. + /// + /// The message that describes the error. + /// The MCP error code. Should be either or . + public McpAuthorizationException(string message, McpErrorCode errorCode) + : base(message, errorCode) + { + if (errorCode != McpErrorCode.Unauthorized && errorCode != McpErrorCode.AuthenticationFailed) + { + throw new ArgumentException($"Error code must be either {nameof(McpErrorCode.Unauthorized)} or {nameof(McpErrorCode.AuthenticationFailed)}", nameof(errorCode)); + } + } + + /// + /// Initializes a new instance of the class with a specified error message, inner exception, and error code. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + /// The MCP error code. Should be either or . + public McpAuthorizationException(string message, Exception? innerException, McpErrorCode errorCode) + : base(message, innerException, errorCode) + { + if (errorCode != McpErrorCode.Unauthorized && errorCode != McpErrorCode.AuthenticationFailed) + { + throw new ArgumentException($"Error code must be either {nameof(McpErrorCode.Unauthorized)} or {nameof(McpErrorCode.AuthenticationFailed)}", nameof(errorCode)); + } + } + + /// + /// Gets or sets the resource that requires authorization. + /// + public string? ResourceUri { get; set; } + + /// + /// Gets or sets the authorization server URI that should be used for authentication. + /// + public string? AuthorizationServerUri { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/McpErrorCode.cs b/src/ModelContextProtocol/McpErrorCode.cs index f6cf4f51..69fcc741 100644 --- a/src/ModelContextProtocol/McpErrorCode.cs +++ b/src/ModelContextProtocol/McpErrorCode.cs @@ -46,4 +46,20 @@ public enum McpErrorCode /// This error is used when the endpoint encounters an unexpected condition that prevents it from fulfilling the request. /// InternalError = -32603, + + /// + /// Indicates that the client is not authorized to access the requested resource. + /// + /// + /// This error is returned when the client lacks the necessary credentials or permissions to access a resource. + /// + Unauthorized = -32401, + + /// + /// Indicates that the authentication process failed. + /// + /// + /// This error is returned when the client provides invalid or expired credentials, or when the authentication flow fails. + /// + AuthenticationFailed = -32402, } diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs new file mode 100644 index 00000000..e65dd4ff --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs @@ -0,0 +1,91 @@ +using System.Diagnostics; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the context for authorization in an MCP client session. +/// +internal class AuthorizationContext +{ + /// + /// Gets or sets the resource metadata. + /// + public ResourceMetadata? ResourceMetadata { get; set; } + + /// + /// Gets or sets the authorization server metadata. + /// + public AuthorizationServerMetadata? AuthorizationServerMetadata { get; set; } + + /// + /// Gets or sets the client registration response. + /// + public ClientRegistrationResponse? ClientRegistration { get; set; } + + /// + /// Gets or sets the token response. + /// + public TokenResponse? TokenResponse { get; set; } + + /// + /// Gets or sets the code verifier for PKCE. + /// + public string? CodeVerifier { get; set; } + + /// + /// Gets or sets the redirect URI used in the authorization flow. + /// + public string? RedirectUri { get; set; } + + /// + /// Gets or sets the time when the access token was issued. + /// + public DateTimeOffset? TokenIssuedAt { get; set; } + + /// + /// Gets a value indicating whether the access token is valid. + /// + public bool HasValidToken => TokenResponse != null && + (TokenResponse.ExpiresIn == null || + TokenIssuedAt != null && + DateTimeOffset.UtcNow < TokenIssuedAt.Value.AddSeconds(TokenResponse.ExpiresIn.Value - 60)); // 1 minute buffer + + /// + /// Gets the access token for authentication. + /// + /// The access token if available, otherwise null. + public string? GetAccessToken() + { + if (!HasValidToken) + { + return null; + } + + // Since HasValidToken checks that TokenResponse isn't null, we should never have null here, + // but we'll add an explicit null check to satisfy the compiler + return TokenResponse?.AccessToken; + } + + /// + /// Gets a value indicating whether a refresh token is available for refreshing the access token. + /// + public bool CanRefreshToken => TokenResponse?.RefreshToken != null && + ClientRegistration != null && + AuthorizationServerMetadata != null; + + /// + /// Validates the URL of a resource against the resource URL from the metadata. + /// + /// The URL to validate. + /// True if the URLs match, otherwise false. + public bool ValidateResourceUrl(string resourceUrl) + { + if (ResourceMetadata == null || string.IsNullOrEmpty(ResourceMetadata.Resource)) + { + return false; + } + + // Resource URL must match exactly + return string.Equals(resourceUrl, ResourceMetadata.Resource, StringComparison.OrdinalIgnoreCase); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs new file mode 100644 index 00000000..9be69e67 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs @@ -0,0 +1,69 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents OAuth 2.0 authorization server metadata as defined in RFC 8414. +/// +internal class AuthorizationServerMetadata +{ + /// + /// Gets or sets the authorization endpoint URL. + /// + [JsonPropertyName("authorization_endpoint")] + public required string AuthorizationEndpoint { get; set; } + + /// + /// Gets or sets the token endpoint URL. + /// + [JsonPropertyName("token_endpoint")] + public required string TokenEndpoint { get; set; } + + /// + /// Gets or sets the client registration endpoint URL. + /// + [JsonPropertyName("registration_endpoint")] + public string? RegistrationEndpoint { get; set; } + + /// + /// Gets or sets the token revocation endpoint URL. + /// + [JsonPropertyName("revocation_endpoint")] + public string? RevocationEndpoint { get; set; } + + /// + /// Gets or sets the response types supported by the authorization server. + /// + [JsonPropertyName("response_types_supported")] + public string[]? ResponseTypesSupported { get; set; } = ["code"]; + + /// + /// Gets or sets the grant types supported by the authorization server. + /// + [JsonPropertyName("grant_types_supported")] + public string[]? GrantTypesSupported { get; set; } = ["authorization_code", "refresh_token"]; + + /// + /// Gets or sets the token endpoint authentication methods supported by the authorization server. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public string[]? TokenEndpointAuthMethodsSupported { get; set; } = ["client_secret_basic"]; + + /// + /// Gets or sets the code challenge methods supported by the authorization server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + public string[]? CodeChallengeMethodsSupported { get; set; } = ["S256"]; + + /// + /// Gets or sets the issuer identifier. + /// + [JsonPropertyName("issuer")] + public string? Issuer { get; set; } + + /// + /// Gets or sets the scopes supported by the authorization server. + /// + [JsonPropertyName("scopes_supported")] + public string[]? ScopesSupported { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs new file mode 100644 index 00000000..e518dd4c --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -0,0 +1,423 @@ +using System.Net; +using System.Net.Http.Headers; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Provides OAuth 2.0 authorization services for MCP clients. +/// +internal class AuthorizationService +{ + private static readonly HttpClient s_httpClient = new() + { + DefaultRequestHeaders = + { + Accept = { new MediaTypeWithQualityHeaderValue("application/json") } + } + }; + + /// + /// Gets resource metadata from a 401 Unauthorized response. + /// + /// The HTTP response that contains the WWW-Authenticate header. + /// A that represents the asynchronous operation. The task result contains the resource metadata if available. + public static async Task GetResourceMetadataFromResponseAsync(HttpResponseMessage response) + { + if (response.StatusCode != HttpStatusCode.Unauthorized) + { + return null; + } + + // Get the WWW-Authenticate header + if (!response.Headers.TryGetValues("WWW-Authenticate", out var authenticateValues)) + { + return null; + } + + // Parse the WWW-Authenticate header + string? resourceMetadataUrl = null; + foreach (var value in authenticateValues) + { + if (value.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + var parameters = ParseAuthHeaderParameters(value["Bearer ".Length..].Trim()); + + if (parameters.TryGetValue("resource_metadata", out var metadataUrl)) + { + resourceMetadataUrl = metadataUrl; + break; + } + } + } + + if (string.IsNullOrEmpty(resourceMetadataUrl)) + { + return null; + } + + // Fetch the resource metadata document + try + { + using var metadataResponse = await s_httpClient.GetAsync(resourceMetadataUrl); + metadataResponse.EnsureSuccessStatusCode(); + + return await JsonSerializer.DeserializeAsync( + await metadataResponse.Content.ReadAsStreamAsync(), + new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }); + } + catch (Exception) + { + // Failed to get resource metadata + return null; + } + } + + /// + /// Discovers authorization server metadata from a well-known endpoint. + /// + /// The base URL of the authorization server. + /// A that represents the asynchronous operation. The task result contains the authorization server metadata. + /// Thrown when both well-known endpoints return errors. + public static async Task DiscoverAuthorizationServerMetadataAsync(string authorizationServerUrl) + { + Throw.IfNullOrWhiteSpace(authorizationServerUrl); + + // Remove trailing slash if present + if (authorizationServerUrl.EndsWith("/")) + { + authorizationServerUrl = authorizationServerUrl[..^1]; + } + + // Try OpenID Connect discovery endpoint + var openIdConfigUrl = $"{authorizationServerUrl}/.well-known/openid-configuration"; + try + { + using var openIdResponse = await s_httpClient.GetAsync(openIdConfigUrl); + if (openIdResponse.IsSuccessStatusCode) + { + return await JsonSerializer.DeserializeAsync( + await openIdResponse.Content.ReadAsStreamAsync(), + new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }) ?? throw new InvalidOperationException("Failed to parse authorization server metadata"); + } + } + catch (Exception ex) when (ex is not InvalidOperationException) + { + // Failed to get OpenID configuration, try OAuth endpoint + } + + // Try OAuth 2.0 Authorization Server Metadata endpoint + var oauthConfigUrl = $"{authorizationServerUrl}/.well-known/oauth-authorization-server"; + try + { + using var oauthResponse = await s_httpClient.GetAsync(oauthConfigUrl); + if (oauthResponse.IsSuccessStatusCode) + { + return await JsonSerializer.DeserializeAsync( + await oauthResponse.Content.ReadAsStreamAsync(), + new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }) ?? throw new InvalidOperationException("Failed to parse authorization server metadata"); + } + } + catch (Exception ex) when (ex is not InvalidOperationException) + { + // Failed to get OAuth configuration + } + + throw new InvalidOperationException( + "Failed to discover authorization server metadata. " + + "Neither OpenID Connect nor OAuth 2.0 well-known endpoints are available."); + } + + /// + /// Registers a client with the authorization server. + /// + /// The authorization server metadata. + /// The client metadata for registration. + /// A that represents the asynchronous operation. The task result contains the client registration response. + /// Thrown when the authorization server does not support dynamic client registration. + public static async Task RegisterClientAsync( + AuthorizationServerMetadata metadata, + ClientMetadata clientMetadata) + { + Throw.IfNull(metadata); + Throw.IfNull(clientMetadata); + + if (metadata.RegistrationEndpoint == null) + { + throw new InvalidOperationException("The authorization server does not support dynamic client registration."); + } + + var content = new StringContent( + JsonSerializer.Serialize(clientMetadata), + Encoding.UTF8, + "application/json"); + + using var response = await s_httpClient.PostAsync(metadata.RegistrationEndpoint, content); + response.EnsureSuccessStatusCode(); + + return await JsonSerializer.DeserializeAsync( + await response.Content.ReadAsStreamAsync(), + new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }) ?? throw new InvalidOperationException("Failed to parse client registration response"); + } + + /// + /// Generates a code verifier and code challenge for PKCE. + /// + /// A tuple containing the code verifier and code challenge. + public static (string CodeVerifier, string CodeChallenge) GeneratePkceValues() + { + // Generate a random code verifier + var bytes = new byte[32]; + using var rng = RandomNumberGenerator.Create(); + rng.GetBytes(bytes); + var codeVerifier = Convert.ToBase64String(bytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + + // Generate the code challenge (S256) + using var sha256 = SHA256.Create(); + var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); + var codeChallenge = Convert.ToBase64String(challengeBytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + + return (codeVerifier, codeChallenge); + } + + /// + /// Creates an authorization URL for the OAuth authorization code flow with PKCE. + /// + /// The authorization server metadata. + /// The client identifier. + /// The redirect URI. + /// The code challenge for PKCE. + /// The requested scopes. + /// A value used to maintain state between the request and callback. + /// The authorization URL. + public static string CreateAuthorizationUrl( + AuthorizationServerMetadata metadata, + string clientId, + string redirectUri, + string codeChallenge, + string[]? scopes = null, + string? state = null) + { + Throw.IfNull(metadata); + Throw.IfNullOrWhiteSpace(clientId); + Throw.IfNullOrWhiteSpace(redirectUri); + Throw.IfNullOrWhiteSpace(codeChallenge); + + var queryBuilder = new StringBuilder(metadata.AuthorizationEndpoint); + queryBuilder.Append(metadata.AuthorizationEndpoint.Contains('?') ? '&' : '?'); + queryBuilder.Append("response_type=code"); + queryBuilder.Append($"&client_id={Uri.EscapeDataString(clientId)}"); + queryBuilder.Append($"&redirect_uri={Uri.EscapeDataString(redirectUri)}"); + queryBuilder.Append($"&code_challenge={Uri.EscapeDataString(codeChallenge)}"); + queryBuilder.Append("&code_challenge_method=S256"); + + if (scopes != null && scopes.Length > 0) + { + queryBuilder.Append($"&scope={Uri.EscapeDataString(string.Join(" ", scopes))}"); + } + + if (!string.IsNullOrEmpty(state)) + { + queryBuilder.Append($"&state={Uri.EscapeDataString(state)}"); + } + + return queryBuilder.ToString(); + } + + /// + /// Exchanges an authorization code for tokens. + /// + /// The authorization server metadata. + /// The client identifier. + /// The client secret. + /// The redirect URI. + /// The authorization code received from the authorization server. + /// The code verifier for PKCE. + /// A that represents the asynchronous operation. The task result contains the token response. + public static async Task ExchangeCodeForTokensAsync( + AuthorizationServerMetadata metadata, + string clientId, + string? clientSecret, + string redirectUri, + string code, + string codeVerifier) + { + Throw.IfNull(metadata); + Throw.IfNullOrWhiteSpace(clientId); + Throw.IfNullOrWhiteSpace(redirectUri); + Throw.IfNullOrWhiteSpace(code); + Throw.IfNullOrWhiteSpace(codeVerifier); + + var tokenRequestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "authorization_code", + ["code"] = code, + ["redirect_uri"] = redirectUri, + ["client_id"] = clientId, + ["code_verifier"] = codeVerifier + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, metadata.TokenEndpoint) + { + Content = tokenRequestContent + }; + + // Add client authentication if client secret is provided + if (!string.IsNullOrEmpty(clientSecret)) + { + var credentials = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{clientId}:{clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", credentials); + } + + using var response = await s_httpClient.SendAsync(request); + response.EnsureSuccessStatusCode(); + + return await JsonSerializer.DeserializeAsync( + await response.Content.ReadAsStreamAsync(), + new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }) ?? throw new InvalidOperationException("Failed to parse token response"); + } + + /// + /// Refreshes an access token using a refresh token. + /// + /// The authorization server metadata. + /// The client identifier. + /// The client secret. + /// The refresh token. + /// A that represents the asynchronous operation. The task result contains the token response. + public static async Task RefreshTokenAsync( + AuthorizationServerMetadata metadata, + string clientId, + string? clientSecret, + string refreshToken) + { + Throw.IfNull(metadata); + Throw.IfNullOrWhiteSpace(clientId); + Throw.IfNullOrWhiteSpace(refreshToken); + + var tokenRequestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "refresh_token", + ["refresh_token"] = refreshToken, + ["client_id"] = clientId + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, metadata.TokenEndpoint) + { + Content = tokenRequestContent + }; + + // Add client authentication if client secret is provided + if (!string.IsNullOrEmpty(clientSecret)) + { + var credentials = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{clientId}:{clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", credentials); + } + + using var response = await s_httpClient.SendAsync(request); + response.EnsureSuccessStatusCode(); + + return await JsonSerializer.DeserializeAsync( + await response.Content.ReadAsStreamAsync(), + new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }) ?? throw new InvalidOperationException("Failed to parse token response"); + } + + private static Dictionary ParseAuthHeaderParameters(string parameters) + { + var result = new Dictionary(StringComparer.OrdinalIgnoreCase); + + var start = 0; + while (start < parameters.Length) + { + // Find the next key=value pair + var equalPos = parameters.IndexOf('=', start); + if (equalPos == -1) + break; + + var key = parameters[start..equalPos].Trim(); + start = equalPos + 1; + + // Check if the value is quoted + if (start < parameters.Length && parameters[start] == '"') + { + start++; // Skip the opening quote + + // Find the closing quote + var endQuote = start; + while (endQuote < parameters.Length) + { + endQuote = parameters.IndexOf('"', endQuote); + if (endQuote == -1) + break; + + // Check if this is an escaped quote + if (endQuote > 0 && parameters[endQuote - 1] == '\\') + { + endQuote++; // Skip the escaped quote + continue; + } + + break; // Found a non-escaped quote + } + + if (endQuote == -1) + endQuote = parameters.Length; // No closing quote found, use the rest of the string + + var value = parameters[start..endQuote]; + result[key] = value.Replace("\\\"", "\""); // Unescape quotes + + // Move past the closing quote and any following comma + start = endQuote + 1; + var commaPos = parameters.IndexOf(',', start); + if (commaPos != -1) + start = commaPos + 1; + else + break; + } + else + { + // Unquoted value, ends at the next comma or end of string + var commaPos = parameters.IndexOf(',', start); + var value = commaPos != -1 + ? parameters[start..commaPos].Trim() + : parameters[start..].Trim(); + + result[key] = value; + + if (commaPos == -1) + break; + + start = commaPos + 1; + } + } + + return result; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ClientMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ClientMetadata.cs new file mode 100644 index 00000000..d650c5a2 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/ClientMetadata.cs @@ -0,0 +1,99 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the OAuth 2.0 client registration metadata as defined in RFC 7591. +/// +public class ClientMetadata +{ + /// + /// Gets or sets the array of redirection URI strings for use in redirect-based flows. + /// + [JsonPropertyName("redirect_uris")] + public required string[] RedirectUris { get; set; } + + /// + /// Gets or sets the requested authentication method for the token endpoint. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; set; } = "client_secret_basic"; + + /// + /// Gets or sets the array of OAuth 2.0 grant type strings that the client can use at the token endpoint. + /// + [JsonPropertyName("grant_types")] + public string[]? GrantTypes { get; set; } = ["authorization_code", "refresh_token"]; + + /// + /// Gets or sets the array of the OAuth 2.0 response type strings that the client can use at the authorization endpoint. + /// + [JsonPropertyName("response_types")] + public string[]? ResponseTypes { get; set; } = ["code"]; + + /// + /// Gets or sets the human-readable string name of the client. + /// + [JsonPropertyName("client_name")] + public string? ClientName { get; set; } + + /// + /// Gets or sets the URL string of a web page providing information about the client. + /// + [JsonPropertyName("client_uri")] + public string? ClientUri { get; set; } + + /// + /// Gets or sets the URL string that references a logo for the client. + /// + [JsonPropertyName("logo_uri")] + public string? LogoUri { get; set; } + + /// + /// Gets or sets the string containing a space-separated list of scope values that the client can use. + /// + [JsonPropertyName("scope")] + public string? Scope { get; set; } + + /// + /// Gets or sets the array of strings representing ways to contact people responsible for this client. + /// + [JsonPropertyName("contacts")] + public string[]? Contacts { get; set; } + + /// + /// Gets or sets the URL string that points to a human-readable terms of service document. + /// + [JsonPropertyName("tos_uri")] + public string? TosUri { get; set; } + + /// + /// Gets or sets the URL string that points to a human-readable privacy policy document. + /// + [JsonPropertyName("policy_uri")] + public string? PolicyUri { get; set; } + + /// + /// Gets or sets the URL string referencing the client's JSON Web Key Set document. + /// + [JsonPropertyName("jwks_uri")] + public string? JwksUri { get; set; } + + /// + /// Gets or sets the client's JSON Web Key Set document value. + /// + [JsonPropertyName("jwks")] + public object? Jwks { get; set; } + + /// + /// Gets or sets a unique identifier string assigned by the client developer or software publisher. + /// + [JsonPropertyName("software_id")] + public string? SoftwareId { get; set; } + + /// + /// Gets or sets the version identifier string for the client software. + /// + [JsonPropertyName("software_version")] + public string? SoftwareVersion { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs b/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs new file mode 100644 index 00000000..06cef8b5 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs @@ -0,0 +1,33 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the OAuth 2.0 client registration response as defined in RFC 7591. +/// +internal class ClientRegistrationResponse +{ + /// + /// Gets or sets the OAuth 2.0 client identifier string. + /// + [JsonPropertyName("client_id")] + public required string ClientId { get; set; } + + /// + /// Gets or sets the OAuth 2.0 client secret string. + /// + [JsonPropertyName("client_secret")] + public string? ClientSecret { get; set; } + + /// + /// Gets or sets the time at which the client identifier was issued. + /// + [JsonPropertyName("client_id_issued_at")] + public long? ClientIdIssuedAt { get; set; } + + /// + /// Gets or sets the time at which the client secret will expire or 0 if it will not expire. + /// + [JsonPropertyName("client_secret_expires_at")] + public long? ClientSecretExpiresAt { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs new file mode 100644 index 00000000..ca1de7b9 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -0,0 +1,265 @@ +using System.Diagnostics; +using System.Net; +using System.Net.Http.Headers; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Provides authorization handling for MCP clients. +/// +internal class DefaultAuthorizationHandler : IAuthorizationHandler +{ + private readonly ILogger _logger; + private readonly SynchronizedValue _authContext = new(new AuthorizationContext()); + private readonly Func>? _authorizeCallback; + + /// + /// Initializes a new instance of the class. + /// + /// The logger factory. + /// A callback function that handles the authorization code flow. + public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, Func>? authorizeCallback = null) + { + _logger = loggerFactory != null + ? loggerFactory.CreateLogger() + : NullLogger.Instance; + _authorizeCallback = authorizeCallback; + } + + /// + public async Task AuthenticateRequestAsync(HttpRequestMessage request) + { + // Try to get a valid token, refreshing if necessary + var token = await GetValidTokenAsync(); + + if (!string.IsNullOrEmpty(token)) + { + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); + } + } + + /// + public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage response, Uri serverUri) + { + if (response.StatusCode != HttpStatusCode.Unauthorized) + { + return false; + } + + _logger.LogDebug("Received 401 Unauthorized response from {ServerUri}", serverUri); + + using var authContext = await _authContext.LockAsync(); + + // If we already have a valid token, it might be that the token was just revoked + // or has other issues - we need to clear our state and retry the authorization flow + if (authContext.Value.HasValidToken) + { + _logger.LogWarning("Server rejected our authentication token. Clearing authentication state and reauthorizing."); + authContext.Value = new AuthorizationContext(); + } + + // Try to get resource metadata from the response + var resourceMetadata = await AuthorizationService.GetResourceMetadataFromResponseAsync(response); + if (resourceMetadata == null) + { + _logger.LogWarning("Failed to extract resource metadata from 401 response"); + + // Create a more specific exception + var exception = new McpAuthorizationException("Authorization required but no resource metadata available") + { + ResourceUri = serverUri.ToString() + }; + throw exception; + } + + // Validate that the resource matches the server FQDN + if (!authContext.Value.ValidateResourceUrl(serverUri.ToString()) && + !string.Equals(resourceMetadata.Resource, serverUri.ToString(), StringComparison.OrdinalIgnoreCase)) + { + _logger.LogWarning("Resource URL mismatch: expected {Expected}, got {Actual}", + serverUri, resourceMetadata.Resource); + + var exception = new McpAuthorizationException($"Resource URL mismatch: expected {serverUri}, got {resourceMetadata.Resource}"); + exception.ResourceUri = resourceMetadata.Resource; + throw exception; + } + + authContext.Value.ResourceMetadata = resourceMetadata; + + // Get the first authorization server from the metadata + if (resourceMetadata.AuthorizationServers == null || resourceMetadata.AuthorizationServers.Length == 0) + { + _logger.LogWarning("No authorization servers found in resource metadata"); + + var exception = new McpAuthorizationException("No authorization servers available"); + exception.ResourceUri = resourceMetadata.Resource; + throw exception; + } + + var authServerUrl = resourceMetadata.AuthorizationServers[0]; + _logger.LogDebug("Using authorization server: {AuthServerUrl}", authServerUrl); + + try + { + // Discover authorization server metadata + var authServerMetadata = await AuthorizationService.DiscoverAuthorizationServerMetadataAsync(authServerUrl); + authContext.Value.AuthorizationServerMetadata = authServerMetadata; + _logger.LogDebug("Successfully retrieved authorization server metadata"); + + // Create client metadata + var clientMetadata = new ClientMetadata + { + RedirectUris = new[] { "http://localhost:8888/callback" }, // Default redirect URI + ClientName = "MCP C# SDK Client", + Scope = string.Join(" ", resourceMetadata.ScopesSupported ?? Array.Empty()) + }; + + // Register client if the server supports it + if (authServerMetadata.RegistrationEndpoint != null) + { + _logger.LogDebug("Registering client with authorization server"); + var clientRegistration = await AuthorizationService.RegisterClientAsync(authServerMetadata, clientMetadata); + authContext.Value.ClientRegistration = clientRegistration; + _logger.LogDebug("Client registered successfully with ID: {ClientId}", clientRegistration.ClientId); + } + else + { + _logger.LogWarning("Authorization server does not support dynamic client registration"); + + var exception = new McpAuthorizationException("Authorization server does not support dynamic client registration"); + exception.ResourceUri = resourceMetadata.Resource; + exception.AuthorizationServerUri = authServerUrl; + throw exception; + } + + // If we have no way to handle user authorization, we can't proceed + if (_authorizeCallback == null) + { + _logger.LogWarning("No authorization callback provided, can't proceed with OAuth flow"); + + var exception = new McpAuthorizationException( + "Authentication is required but no authorization callback was provided. " + + "Use SseClientTransportOptions.AuthorizeCallback to provide a callback function."); + exception.ResourceUri = resourceMetadata.Resource; + exception.AuthorizationServerUri = authServerUrl; + throw exception; + } + + // Generate PKCE values + var (codeVerifier, codeChallenge) = AuthorizationService.GeneratePkceValues(); + authContext.Value.CodeVerifier = codeVerifier; + + // Initiate authorization code flow + _logger.LogDebug("Initiating authorization code flow"); + + // Get the registered client ID + var clientId = authContext.Value.ClientRegistration!.ClientId; + + // Get the authorization URL that the user needs to visit + var authUrl = AuthorizationService.CreateAuthorizationUrl( + authServerMetadata, + clientId, + clientMetadata.RedirectUris[0], + codeChallenge, + resourceMetadata.ScopesSupported); + + _logger.LogDebug("Authorization URL: {AuthUrl}", authUrl); + + // Let the callback handle the user authorization + var (redirectUri, code) = await _authorizeCallback(clientMetadata); + authContext.Value.RedirectUri = redirectUri; + + // Exchange the code for tokens + _logger.LogDebug("Exchanging authorization code for tokens"); + var tokenResponse = await AuthorizationService.ExchangeCodeForTokensAsync( + authServerMetadata, + clientId, + authContext.Value.ClientRegistration.ClientSecret, + redirectUri, + code, + codeVerifier); + + authContext.Value.TokenResponse = tokenResponse; + authContext.Value.TokenIssuedAt = DateTimeOffset.UtcNow; + + _logger.LogDebug("Successfully obtained access token"); + return true; + } + catch (Exception ex) when (ex is not McpAuthorizationException) + { + _logger.LogError(ex, "Failed to complete authorization flow"); + + var authException = new McpAuthorizationException( + $"Failed to complete authorization flow: {ex.Message}", ex, McpErrorCode.AuthenticationFailed); + + authException.ResourceUri = resourceMetadata.Resource; + authException.AuthorizationServerUri = authServerUrl; + + throw authException; + } + } + + private async Task GetValidTokenAsync() + { + using var authContext = await _authContext.LockAsync(); + + // If we have a valid token, use it + if (authContext.Value.HasValidToken) + { + _logger.LogDebug("Using existing valid access token"); + return authContext.Value.GetAccessToken(); + } + + // If we can refresh the token, do so + if (authContext.Value.CanRefreshToken) + { + try + { + _logger.LogDebug("Refreshing access token"); + + // Null checks to ensure parameters are valid + if (authContext.Value.AuthorizationServerMetadata == null) + { + _logger.LogError("Cannot refresh token: AuthorizationServerMetadata is null"); + return null; + } + + if (authContext.Value.ClientRegistration == null) + { + _logger.LogError("Cannot refresh token: ClientRegistration is null"); + return null; + } + + if (authContext.Value.TokenResponse?.RefreshToken == null) + { + _logger.LogError("Cannot refresh token: RefreshToken is null"); + return null; + } + + var tokenResponse = await AuthorizationService.RefreshTokenAsync( + authContext.Value.AuthorizationServerMetadata, + authContext.Value.ClientRegistration.ClientId, + authContext.Value.ClientRegistration.ClientSecret, + authContext.Value.TokenResponse.RefreshToken); + + authContext.Value.TokenResponse = tokenResponse; + authContext.Value.TokenIssuedAt = DateTimeOffset.UtcNow; + + _logger.LogDebug("Successfully refreshed access token"); + return tokenResponse.AccessToken; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to refresh access token"); + // Clear the token so we'll try to reauthenticate on the next request + authContext.Value.TokenResponse = null; + authContext.Value.TokenIssuedAt = null; + } + } + + return null; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs new file mode 100644 index 00000000..85f5a61d --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs @@ -0,0 +1,24 @@ +using System.Net; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Defines methods for handling authorization in an MCP client. +/// +public interface IAuthorizationHandler +{ + /// + /// Handles authentication for HTTP requests. + /// + /// The HTTP request to authenticate. + /// A representing the asynchronous operation. + Task AuthenticateRequestAsync(HttpRequestMessage request); + + /// + /// Handles a 401 Unauthorized response. + /// + /// The HTTP response that contains the 401 status code. + /// The URI of the server that returned the 401 status code. + /// A that represents the asynchronous operation. The task result contains true if the authentication was successful and the request should be retried, otherwise false. + Task HandleUnauthorizedResponseAsync(HttpResponseMessage response, Uri serverUri); +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs new file mode 100644 index 00000000..bf6613a1 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the resource metadata from the WWW-Authenticate header in a 401 Unauthorized response. +/// +internal class ResourceMetadata +{ + /// + /// Gets or sets the resource identifier URI. + /// + [JsonPropertyName("resource")] + public required string Resource { get; set; } + + /// + /// Gets or sets the authorization servers that can be used for authentication. + /// + [JsonPropertyName("authorization_servers")] + public required string[] AuthorizationServers { get; set; } + + /// + /// Gets or sets the bearer token methods supported by the resource. + /// + [JsonPropertyName("bearer_methods_supported")] + public string[]? BearerMethodsSupported { get; set; } + + /// + /// Gets or sets the scopes supported by the resource. + /// + [JsonPropertyName("scopes_supported")] + public string[]? ScopesSupported { get; set; } + + /// + /// Gets or sets the URL to the resource documentation. + /// + [JsonPropertyName("resource_documentation")] + public string? ResourceDocumentation { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs b/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs new file mode 100644 index 00000000..2c5faefe --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the OAuth 2.0 token response as defined in RFC 6749. +/// +internal class TokenResponse +{ + /// + /// Gets or sets the access token issued by the authorization server. + /// + [JsonPropertyName("access_token")] + public required string AccessToken { get; set; } + + /// + /// Gets or sets the type of the token issued. + /// + [JsonPropertyName("token_type")] + public required string TokenType { get; set; } + + /// + /// Gets or sets the lifetime in seconds of the access token. + /// + [JsonPropertyName("expires_in")] + public long? ExpiresIn { get; set; } + + /// + /// Gets or sets the refresh token, which can be used to obtain new access tokens. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the scope of the access token. + /// + [JsonPropertyName("scope")] + public string? Scope { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 5d952f8a..f77b49fa 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -1,9 +1,11 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Diagnostics; +using System.Net; using System.Net.Http.Headers; using System.Net.ServerSentEvents; using System.Text; @@ -24,6 +26,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase private Task? _receiveTask; private readonly ILogger _logger; private readonly TaskCompletionSource _connectionEstablished; + private readonly IAuthorizationHandler _authorizationHandler; /// /// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server. @@ -45,6 +48,10 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Htt _connectionCts = new CancellationTokenSource(); _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _connectionEstablished = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + // Initialize the authorization handler + _authorizationHandler = transportOptions.AuthorizationHandler ?? + new DefaultAuthorizationHandler(loggerFactory, transportOptions.AuthorizeCallback); } /// @@ -87,56 +94,108 @@ public override async Task SendMessageAsync( messageId = messageWithId.Id.ToString(); } - var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) + using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) { Content = content, }; + + // Add authorization headers if needed + await _authorizationHandler.AuthenticateRequestAsync(httpRequestMessage).ConfigureAwait(false); + + // Copy additional headers CopyAdditionalHeaders(httpRequestMessage.Headers); - var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + + // Send the request, handling potential auth challenges + HttpResponseMessage? response = null; + bool authRetry = false; + + do + { + authRetry = false; + response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + + // Handle 401 Unauthorized response + if (response.StatusCode == HttpStatusCode.Unauthorized) + { + // Try to handle the unauthorized response + authRetry = await _authorizationHandler.HandleUnauthorizedResponseAsync( + response, _messageEndpoint).ConfigureAwait(false); + + if (authRetry) + { + // Create a new request (we can't reuse the previous one) + using var newRequest = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) + { + Content = new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json" + ) + }; + + // Add authorization headers for the new request + await _authorizationHandler.AuthenticateRequestAsync(newRequest).ConfigureAwait(false); + CopyAdditionalHeaders(newRequest.Headers); + + // Dispose the previous response + response.Dispose(); + + // Send the new request + response = await _httpClient.SendAsync(newRequest, cancellationToken).ConfigureAwait(false); + } + } + } while (authRetry); - response.EnsureSuccessStatusCode(); + try + { + response.EnsureSuccessStatusCode(); - var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - // Check if the message was an initialize request - if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize) - { - // If the response is not a JSON-RPC response, it is an SSE message + // Check if the message was an initialize request + if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize) + { + // If the response is not a JSON-RPC response, it is an SSE message + if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) + { + LogAcceptedPost(Name, messageId); + // The response will arrive as an SSE message + } + else + { + JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ?? + throw new InvalidOperationException("Failed to initialize client"); + + LogTransportReceivedMessage(Name, messageId); + await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false); + LogTransportMessageWritten(Name, messageId); + } + + return; + } + + // Otherwise, check if the response was accepted (the response will come as an SSE message) if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) { LogAcceptedPost(Name, messageId); - // The response will arrive as an SSE message } else { - JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ?? - throw new InvalidOperationException("Failed to initialize client"); + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogRejectedPostSensitive(Name, messageId, responseContent); + } + else + { + LogRejectedPost(Name, messageId); + } - LogTransportReceivedMessage(Name, messageId); - await WriteMessageAsync(initializeResponse, cancellationToken).ConfigureAwait(false); - LogTransportMessageWritten(Name, messageId); + throw new InvalidOperationException("Failed to send message"); } - - return; } - - // Otherwise, check if the response was accepted (the response will come as an SSE message) - if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) + finally { - LogAcceptedPost(Name, messageId); - } - else - { - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogRejectedPostSensitive(Name, messageId, responseContent); - } - else - { - LogRejectedPost(Name, messageId); - } - - throw new InvalidOperationException("Failed to send message"); + response.Dispose(); } } @@ -187,13 +246,55 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Add authorization headers if needed + await _authorizationHandler.AuthenticateRequestAsync(request).ConfigureAwait(false); + + // Copy additional headers CopyAdditionalHeaders(request.Headers); - using var response = await _httpClient.SendAsync( - request, - HttpCompletionOption.ResponseHeadersRead, - cancellationToken - ).ConfigureAwait(false); + // Send the request, handling potential auth challenges + HttpResponseMessage? response = null; + bool authRetry = false; + + do + { + authRetry = false; + response = await _httpClient.SendAsync( + request, + HttpCompletionOption.ResponseHeadersRead, + cancellationToken + ).ConfigureAwait(false); + + // Handle 401 Unauthorized response + if (response.StatusCode == HttpStatusCode.Unauthorized) + { + // Try to handle the unauthorized response + authRetry = await _authorizationHandler.HandleUnauthorizedResponseAsync( + response, _sseEndpoint).ConfigureAwait(false); + + if (authRetry) + { + // Create a new request (we can't reuse the previous one) + using var newRequest = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); + newRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Add authorization headers for the new request + await _authorizationHandler.AuthenticateRequestAsync(newRequest).ConfigureAwait(false); + CopyAdditionalHeaders(newRequest.Headers); + + // Dispose the previous response + response.Dispose(); + + // Send the new request + response = await _httpClient.SendAsync( + newRequest, + HttpCompletionOption.ResponseHeadersRead, + cancellationToken + ).ConfigureAwait(false); + } + } + } while (authRetry); response.EnsureSuccessStatusCode(); diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 832d6727..fe8bfe2d 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Utils; namespace ModelContextProtocol.Protocol.Transport; @@ -7,10 +8,17 @@ namespace ModelContextProtocol.Protocol.Transport; /// Provides an over HTTP using the Server-Sent Events (SSE) protocol. /// /// +/// /// This transport connects to an MCP server over HTTP using SSE, /// allowing for real-time server-to-client communication with a standard HTTP request. /// Unlike the , this transport connects to an existing server /// rather than launching a new process. +/// +/// +/// The SSE transport can handle OAuth 2.0 authorization flows when connecting to servers that require authentication. +/// You can provide an in the transport options to handle the user authentication part +/// of the OAuth flow. +/// /// public sealed class SseClientTransport : IClientTransport, IAsyncDisposable { @@ -54,6 +62,115 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient /// public string Name { get; } + /// + /// Creates a delegate that can handle the OAuth 2.0 authorization code flow. + /// + /// A function that opens a URL in the browser. + /// The local port to listen on for the redirect URI. + /// The path for the redirect URI. + /// A delegate that can be used for the property. + /// + /// + /// This method creates a delegate that implements a complete local OAuth 2.0 authorization code flow. + /// When called, it will: + /// + /// + /// Open the authorization URL in the browser + /// Start a local HTTP server to listen for the authorization code + /// Return the redirect URI and authorization code when received + /// + /// + /// You can customize the port and path for the redirect URI. By default, it uses port 8888 and path "/callback". + /// + /// + public static Func> CreateLocalServerAuthorizeCallback( + Func openBrowser, + int listenPort = 8888, + string redirectPath = "/callback") + { + return async (ClientMetadata clientMetadata) => + { + var redirectUri = $"http://localhost:{listenPort}{redirectPath}"; + + // Use a TaskCompletionSource to wait for the authorization code + var authCodeTcs = new TaskCompletionSource(); + + // Start a local HTTP server to listen for the authorization code + using var listener = new System.Net.HttpListener(); + listener.Prefixes.Add($"http://localhost:{listenPort}/"); + listener.Start(); + + // Start listening for the callback asynchronously + var listenerTask = Task.Run(async () => + { + try + { + var context = await listener.GetContextAsync(); + var request = context.Request; + + // Get the authorization code from the query string + var code = request.QueryString["code"]; + var error = request.QueryString["error"]; + + // Send a response to the browser + var response = context.Response; + response.ContentType = "text/html"; + var responseHtml = "

Authorization Successful

You can now close this window and return to the application.

"; + + if (!string.IsNullOrEmpty(error)) + { + responseHtml = $"

Authorization Failed

Error: {error}

"; + authCodeTcs.SetException(new McpException($"Authorization failed: {error}", McpErrorCode.AuthenticationFailed)); + } + else if (string.IsNullOrEmpty(code)) + { + responseHtml = "

Authorization Failed

No authorization code received.

"; + authCodeTcs.SetException(new McpException("No authorization code received", McpErrorCode.AuthenticationFailed)); + } + else + { + authCodeTcs.SetResult(code); + } + + var buffer = System.Text.Encoding.UTF8.GetBytes(responseHtml); + response.ContentLength64 = buffer.Length; + await response.OutputStream.WriteAsync(buffer, 0, buffer.Length); + response.Close(); + } + catch (Exception ex) + { + authCodeTcs.TrySetException(ex); + } + finally + { + listener.Close(); + } + }); + + // Open the authorization URL in the browser + foreach (var uri in clientMetadata.RedirectUris) + { + if (uri.StartsWith("http://localhost")) + { + redirectUri = uri; + break; + } + } + + // We need to actually open the browser with the authorization URL + // Find the auth URL from client metadata and pass to openBrowser + if (clientMetadata.ClientUri != null) + { + await openBrowser(clientMetadata.ClientUri); + } + + // Wait for the authorization code + var code = await authCodeTcs.Task; + + return (redirectUri, code); + }; + } + /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index 0a36a15f..0ee1c0fd 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -1,5 +1,7 @@ namespace ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Auth; + /// /// Provides options for configuring instances. /// @@ -55,4 +57,31 @@ public required Uri Endpoint /// Use this property to specify custom HTTP headers that should be sent with each request to the server. /// public Dictionary? AdditionalHeaders { get; init; } + + /// + /// Gets or sets a delegate that handles the OAuth 2.0 authorization code flow. + /// + /// + /// + /// This delegate is called when the SSE server requires OAuth 2.0 authorization. It receives the client metadata + /// and should return the redirect URI and authorization code received from the authorization server. + /// + /// + /// If not provided, the client will not be able to authenticate with servers that require OAuth authentication. + /// + /// + public Func>? AuthorizeCallback { get; init; } + + /// + /// Gets or sets a custom authorization handler. + /// + /// + /// + /// If specified, this handler will be used to manage authorization with the SSE server. + /// + /// + /// If not provided, a default handler will be created using the . + /// + /// + public IAuthorizationHandler? AuthorizationHandler { get; init; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Utils/SynchronizedValue.cs b/src/ModelContextProtocol/Utils/SynchronizedValue.cs new file mode 100644 index 00000000..106bf29b --- /dev/null +++ b/src/ModelContextProtocol/Utils/SynchronizedValue.cs @@ -0,0 +1,75 @@ +namespace ModelContextProtocol.Utils; + +/// +/// Provides a thread-safe synchronized value with locking functionality. +/// +/// The type of value to synchronize. +internal class SynchronizedValue where T : class +{ + private readonly SemaphoreSlim _semaphore = new(1, 1); + private T _value; + + /// + /// Initializes a new instance of the class. + /// + /// The initial value. + public SynchronizedValue(T initialValue) + { + _value = initialValue; + } + + /// + /// Gets the current value without locking. + /// + /// + /// This property should only be used when thread safety is not required. + /// + public T UnsafeValue => _value; + + /// + /// Acquires a lock on the value and provides access to it. + /// + /// A disposable that provides access to the value and releases the lock when disposed. + public async Task LockAsync() + { + await _semaphore.WaitAsync().ConfigureAwait(false); + return new SynchronizedValueHandle(this); + } + + /// + /// Provides a handle to access the synchronized value while holding a lock. + /// + public class SynchronizedValueHandle : IDisposable + { + private readonly SynchronizedValue _parent; + private bool _disposed; + + internal SynchronizedValueHandle(SynchronizedValue parent) + { + _parent = parent; + } + + /// + /// Gets or sets the synchronized value. + /// + public T Value + { + get => _parent._value; + set => _parent._value = value; + } + + /// + /// Releases the lock on the synchronized value. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + _parent._semaphore.Release(); + } + } +} \ No newline at end of file From 389fb3d028f6382d06ab579c8189d427c619887f Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:42:07 -0700 Subject: [PATCH 02/28] Fix JSON reference issues --- .../Protocol/Auth/AuthorizationService.cs | 107 ++++++++++++------ .../Protocol/Transport/SseClientTransport.cs | 2 +- .../Utils/Json/McpJsonUtilities.cs | 7 ++ 3 files changed, 78 insertions(+), 38 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs index e518dd4c..a6bc5807 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -4,6 +4,7 @@ using System.Text; using System.Text.Json; using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; namespace ModelContextProtocol.Protocol.Auth; @@ -65,12 +66,12 @@ internal class AuthorizationService using var metadataResponse = await s_httpClient.GetAsync(resourceMetadataUrl); metadataResponse.EnsureSuccessStatusCode(); - return await JsonSerializer.DeserializeAsync( - await metadataResponse.Content.ReadAsStreamAsync(), - new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true - }); + var contentStream = await metadataResponse.Content.ReadAsStreamAsync(); + + // Read as string first, then deserialize using source-generated serializer + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + return JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ResourceMetadata); } catch (Exception) { @@ -102,12 +103,19 @@ public static async Task DiscoverAuthorizationServe using var openIdResponse = await s_httpClient.GetAsync(openIdConfigUrl); if (openIdResponse.IsSuccessStatusCode) { - return await JsonSerializer.DeserializeAsync( - await openIdResponse.Content.ReadAsStreamAsync(), - new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true - }) ?? throw new InvalidOperationException("Failed to parse authorization server metadata"); + var contentStream = await openIdResponse.Content.ReadAsStreamAsync(); + + // Use source-generated serialization instead of dynamic deserialization + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse authorization server metadata"); + } + + return result; } } catch (Exception ex) when (ex is not InvalidOperationException) @@ -122,12 +130,19 @@ await openIdResponse.Content.ReadAsStreamAsync(), using var oauthResponse = await s_httpClient.GetAsync(oauthConfigUrl); if (oauthResponse.IsSuccessStatusCode) { - return await JsonSerializer.DeserializeAsync( - await oauthResponse.Content.ReadAsStreamAsync(), - new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true - }) ?? throw new InvalidOperationException("Failed to parse authorization server metadata"); + var contentStream = await oauthResponse.Content.ReadAsStreamAsync(); + + // Use source-generated serialization instead of dynamic deserialization + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse authorization server metadata"); + } + + return result; } } catch (Exception ex) when (ex is not InvalidOperationException) @@ -160,19 +175,25 @@ public static async Task RegisterClientAsync( } var content = new StringContent( - JsonSerializer.Serialize(clientMetadata), + JsonSerializer.Serialize(clientMetadata, McpJsonUtilities.JsonContext.Default.ClientMetadata), Encoding.UTF8, "application/json"); using var response = await s_httpClient.PostAsync(metadata.RegistrationEndpoint, content); response.EnsureSuccessStatusCode(); - return await JsonSerializer.DeserializeAsync( - await response.Content.ReadAsStreamAsync(), - new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true - }) ?? throw new InvalidOperationException("Failed to parse client registration response"); + // Use source-generated serialization instead of dynamic deserialization + var contentStream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ClientRegistrationResponse); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse client registration response"); + } + + return result; } /// @@ -293,12 +314,18 @@ public static async Task ExchangeCodeForTokensAsync( using var response = await s_httpClient.SendAsync(request); response.EnsureSuccessStatusCode(); - return await JsonSerializer.DeserializeAsync( - await response.Content.ReadAsStreamAsync(), - new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true - }) ?? throw new InvalidOperationException("Failed to parse token response"); + // Use source-generated serialization instead of dynamic deserialization + var contentStream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.TokenResponse); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse token response"); + } + + return result; } /// @@ -341,12 +368,18 @@ public static async Task RefreshTokenAsync( using var response = await s_httpClient.SendAsync(request); response.EnsureSuccessStatusCode(); - return await JsonSerializer.DeserializeAsync( - await response.Content.ReadAsStreamAsync(), - new JsonSerializerOptions - { - PropertyNameCaseInsensitive = true - }) ?? throw new InvalidOperationException("Failed to parse token response"); + // Use source-generated serialization instead of dynamic deserialization + var contentStream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.TokenResponse); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse token response"); + } + + return result; } private static Dictionary ParseAuthHeaderParameters(string parameters) diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index fe8bfe2d..6072875c 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -16,7 +16,7 @@ namespace ModelContextProtocol.Protocol.Transport; /// /// /// The SSE transport can handle OAuth 2.0 authorization flows when connecting to servers that require authentication. -/// You can provide an in the transport options to handle the user authentication part +/// You can provide an in the transport options to handle the user authentication part /// of the OAuth flow. /// /// diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index b759ba97..a8b2b996 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -122,6 +122,13 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(SubscribeRequestParams))] [JsonSerializable(typeof(UnsubscribeRequestParams))] [JsonSerializable(typeof(IReadOnlyDictionary))] + + // Authorization-related types + [JsonSerializable(typeof(Protocol.Auth.ResourceMetadata))] + [JsonSerializable(typeof(Protocol.Auth.AuthorizationServerMetadata))] + [JsonSerializable(typeof(Protocol.Auth.ClientMetadata))] + [JsonSerializable(typeof(Protocol.Auth.ClientRegistrationResponse))] + [JsonSerializable(typeof(Protocol.Auth.TokenResponse))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; From ddca6cc41fd71884b2c24b23b6b7e1aed167858b Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:00:01 -0700 Subject: [PATCH 03/28] Remove error codes - we don't use those --- .../McpAuthorizationException.cs | 18 +++++------------- src/ModelContextProtocol/McpErrorCode.cs | 16 ---------------- .../Auth/DefaultAuthorizationHandler.cs | 2 +- .../Protocol/Transport/SseClientTransport.cs | 4 ++-- 4 files changed, 8 insertions(+), 32 deletions(-) diff --git a/src/ModelContextProtocol/McpAuthorizationException.cs b/src/ModelContextProtocol/McpAuthorizationException.cs index 6dca9a2f..93eb4679 100644 --- a/src/ModelContextProtocol/McpAuthorizationException.cs +++ b/src/ModelContextProtocol/McpAuthorizationException.cs @@ -13,7 +13,7 @@ public class McpAuthorizationException : McpException /// Initializes a new instance of the class. /// public McpAuthorizationException() - : base("Authorization failed", McpErrorCode.Unauthorized) + : base("Authorization failed", McpErrorCode.InvalidRequest) { } @@ -22,7 +22,7 @@ public McpAuthorizationException() /// /// The message that describes the error. public McpAuthorizationException(string message) - : base(message, McpErrorCode.Unauthorized) + : base(message, McpErrorCode.InvalidRequest) { } @@ -32,7 +32,7 @@ public McpAuthorizationException(string message) /// The message that describes the error. /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. public McpAuthorizationException(string message, Exception? innerException) - : base(message, innerException, McpErrorCode.Unauthorized) + : base(message, innerException, McpErrorCode.InvalidRequest) { } @@ -40,14 +40,10 @@ public McpAuthorizationException(string message, Exception? innerException) /// Initializes a new instance of the class with a specified error message and error code. /// /// The message that describes the error. - /// The MCP error code. Should be either or . + /// The MCP error code. Should use one of the standard error codes. public McpAuthorizationException(string message, McpErrorCode errorCode) : base(message, errorCode) { - if (errorCode != McpErrorCode.Unauthorized && errorCode != McpErrorCode.AuthenticationFailed) - { - throw new ArgumentException($"Error code must be either {nameof(McpErrorCode.Unauthorized)} or {nameof(McpErrorCode.AuthenticationFailed)}", nameof(errorCode)); - } } /// @@ -55,14 +51,10 @@ public McpAuthorizationException(string message, McpErrorCode errorCode) /// /// The message that describes the error. /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. - /// The MCP error code. Should be either or . + /// The MCP error code. Should use one of the standard error codes. public McpAuthorizationException(string message, Exception? innerException, McpErrorCode errorCode) : base(message, innerException, errorCode) { - if (errorCode != McpErrorCode.Unauthorized && errorCode != McpErrorCode.AuthenticationFailed) - { - throw new ArgumentException($"Error code must be either {nameof(McpErrorCode.Unauthorized)} or {nameof(McpErrorCode.AuthenticationFailed)}", nameof(errorCode)); - } } /// diff --git a/src/ModelContextProtocol/McpErrorCode.cs b/src/ModelContextProtocol/McpErrorCode.cs index 69fcc741..f6cf4f51 100644 --- a/src/ModelContextProtocol/McpErrorCode.cs +++ b/src/ModelContextProtocol/McpErrorCode.cs @@ -46,20 +46,4 @@ public enum McpErrorCode /// This error is used when the endpoint encounters an unexpected condition that prevents it from fulfilling the request. /// InternalError = -32603, - - /// - /// Indicates that the client is not authorized to access the requested resource. - /// - /// - /// This error is returned when the client lacks the necessary credentials or permissions to access a resource. - /// - Unauthorized = -32401, - - /// - /// Indicates that the authentication process failed. - /// - /// - /// This error is returned when the client provides invalid or expired credentials, or when the authentication flow fails. - /// - AuthenticationFailed = -32402, } diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs index ca1de7b9..20f32922 100644 --- a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -193,7 +193,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp _logger.LogError(ex, "Failed to complete authorization flow"); var authException = new McpAuthorizationException( - $"Failed to complete authorization flow: {ex.Message}", ex, McpErrorCode.AuthenticationFailed); + $"Failed to complete authorization flow: {ex.Message}", ex, McpErrorCode.InvalidRequest); authException.ResourceUri = resourceMetadata.Resource; authException.AuthorizationServerUri = authServerUrl; diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 6072875c..e96dfd75 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -120,12 +120,12 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient if (!string.IsNullOrEmpty(error)) { responseHtml = $"

Authorization Failed

Error: {error}

"; - authCodeTcs.SetException(new McpException($"Authorization failed: {error}", McpErrorCode.AuthenticationFailed)); + authCodeTcs.SetException(new McpException($"Authorization failed: {error}", McpErrorCode.InvalidRequest)); } else if (string.IsNullOrEmpty(code)) { responseHtml = "

Authorization Failed

No authorization code received.

"; - authCodeTcs.SetException(new McpException("No authorization code received", McpErrorCode.AuthenticationFailed)); + authCodeTcs.SetException(new McpException("No authorization code received", McpErrorCode.InvalidRequest)); } else { From 99a417f9063142fda251b0710a0bc8f2a2538be5 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:17:06 -0700 Subject: [PATCH 04/28] Update sample to use proper SSE transport definition --- ModelContextProtocol.sln | 7 +++ samples/AuthorizationExample/Program.cs | 74 ++++++++----------------- 2 files changed, 29 insertions(+), 52 deletions(-) diff --git a/ModelContextProtocol.sln b/ModelContextProtocol.sln index 0e4fd721..950b71f2 100644 --- a/ModelContextProtocol.sln +++ b/ModelContextProtocol.sln @@ -56,6 +56,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol.AspNet EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol.AspNetCore.Tests", "tests\ModelContextProtocol.AspNetCore.Tests\ModelContextProtocol.AspNetCore.Tests.csproj", "{85557BA6-3D29-4C95-A646-2A972B1C2F25}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AuthorizationExample", "samples\AuthorizationExample\AuthorizationExample.csproj", "{C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -110,6 +112,10 @@ Global {85557BA6-3D29-4C95-A646-2A972B1C2F25}.Debug|Any CPU.Build.0 = Debug|Any CPU {85557BA6-3D29-4C95-A646-2A972B1C2F25}.Release|Any CPU.ActiveCfg = Release|Any CPU {85557BA6-3D29-4C95-A646-2A972B1C2F25}.Release|Any CPU.Build.0 = Release|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -128,6 +134,7 @@ Global {17B8453F-AB72-99C5-E5EA-D0B065A6AE65} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} {37B6A5E0-9995-497D-8B43-3BC6870CC716} = {A2F1F52A-9107-4BF8-8C3F-2F6670E7D0AD} {85557BA6-3D29-4C95-A646-2A972B1C2F25} = {2A77AF5C-138A-4EBB-9A13-9205DCD67928} + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {384A3888-751F-4D75-9AE5-587330582D89} diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 2594863a..4b0babb1 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -18,64 +18,34 @@ public static async Task Main(string[] args) var transportOptions = new SseClientTransportOptions { Endpoint = serverEndpoint, - - // Provide a callback to handle the authorization flow - AuthorizeCallback = async (clientMetadata) => - { - Console.WriteLine("Authentication required. Opening browser for authorization..."); - - // In a real app, you'd likely have a local HTTP server to receive the callback - // This is just a simplified example - Console.WriteLine("Once you've authorized in the browser, enter the code and redirect URI:"); - Console.Write("Code: "); - var code = Console.ReadLine() ?? ""; - Console.Write("Redirect URI: "); - var redirectUri = Console.ReadLine() ?? "http://localhost:8888/callback"; - - return (redirectUri, code); - } - - // Alternatively, use the built-in local server handler: - // AuthorizeCallback = SseClientTransport.CreateLocalServerAuthorizeCallback( - // openBrowser: async (url) => - // { - // // Open the URL in the user's default browser - // Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); - // } - // ) + AuthorizeCallback = SseClientTransport.CreateLocalServerAuthorizeCallback( + openBrowser: async (url) => + { + // Open the URL in the user's default browser + Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); + } + ) }; try { // Create the client with authorization-enabled transport var transport = new SseClientTransport(transportOptions); - var client = await McpClient.CreateAsync(transport); + var client = await McpClientFactory.CreateAsync(transport); - // Use the MCP client normally - authorization is handled automatically - // If the server returns a 401 Unauthorized response, the authorization flow will be triggered - var result = await client.PingAsync(); - Console.WriteLine($"Server ping successful: {result.ServerInfo.Name} {result.ServerInfo.Version}"); - - // Example tool call - var weatherPrompt = "What's the weather like today?"; - var weatherResult = await client.CompletionCompleteAsync( - new CompletionCompleteRequestBuilder(weatherPrompt).Build()); - - Console.WriteLine($"Response: {weatherResult.Content.Text}"); - } - catch (McpAuthorizationException authEx) - { - Console.WriteLine($"Authorization error: {authEx.Message}"); - Console.WriteLine($"Resource: {authEx.ResourceUri}"); - Console.WriteLine($"Auth server: {authEx.AuthorizationServerUri}"); - } - catch (McpException mcpEx) - { - Console.WriteLine($"MCP error: {mcpEx.Message} (Error code: {mcpEx.ErrorCode})"); - } - catch (Exception ex) - { - Console.WriteLine($"Unexpected error: {ex.Message}"); + // Print the list of tools available from the server. + foreach (var tool in await client.ListToolsAsync()) + { + Console.WriteLine($"{tool.Name} ({tool.Description})"); + } + + // Execute a tool (this would normally be driven by LLM tool invocations). + var result = await client.CallToolAsync( + "echo", + new Dictionary() { ["message"] = "Hello MCP!" }, + cancellationToken: CancellationToken.None); + + // echo always returns one and only one text content object + Console.WriteLine(result.Content.First(c => c.Type == "text").Text); } - } } \ No newline at end of file From 53c1151db7d150d312262faf8b6266241fc32eea Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 16:17:17 -0700 Subject: [PATCH 05/28] Stub for server implementation --- ModelContextProtocol.sln | 7 + samples/AuthorizationExample/Program.cs | 34 ++--- .../AuthorizationServerExample.csproj | 15 ++ samples/AuthorizationServerExample/Program.cs | 139 ++++++++++++++++++ .../Properties/launchSettings.json | 12 ++ .../McpAuthorizationExtensions.cs | 21 +++ .../McpAuthorizationMiddleware.cs | 113 ++++++++++++++ .../McpServerAuthorizationExtensions.cs | 38 +++++ .../Auth/IMcpServerAuthorizationProvider.cs | 26 ++++ .../Auth/ProtectedResourceMetadata.cs | 58 ++++++++ .../Protocol/Types/AuthorizationCapability.cs | 19 +++ .../Protocol/Types/ServerCapabilities.cs | 6 + .../Auth/SimpleServerAuthorizationProvider.cs | 54 +++++++ src/ModelContextProtocol/Server/McpServer.cs | 10 ++ .../Utils/Json/McpJsonUtilities.cs | 1 + 15 files changed, 535 insertions(+), 18 deletions(-) create mode 100644 samples/AuthorizationServerExample/AuthorizationServerExample.csproj create mode 100644 samples/AuthorizationServerExample/Program.cs create mode 100644 samples/AuthorizationServerExample/Properties/launchSettings.json create mode 100644 src/ModelContextProtocol.AspNetCore/McpAuthorizationExtensions.cs create mode 100644 src/ModelContextProtocol.AspNetCore/McpAuthorizationMiddleware.cs create mode 100644 src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/IMcpServerAuthorizationProvider.cs create mode 100644 src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs create mode 100644 src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs create mode 100644 src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs diff --git a/ModelContextProtocol.sln b/ModelContextProtocol.sln index 950b71f2..c033ea40 100644 --- a/ModelContextProtocol.sln +++ b/ModelContextProtocol.sln @@ -58,6 +58,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol.AspNet EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AuthorizationExample", "samples\AuthorizationExample\AuthorizationExample.csproj", "{C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AuthorizationServerExample", "samples\AuthorizationServerExample\AuthorizationServerExample.csproj", "{05C500AF-9CF6-C2E7-2782-95271975A5DE}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -116,6 +118,10 @@ Global {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Debug|Any CPU.Build.0 = Debug|Any CPU {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Release|Any CPU.ActiveCfg = Release|Any CPU {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Release|Any CPU.Build.0 = Release|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -135,6 +141,7 @@ Global {37B6A5E0-9995-497D-8B43-3BC6870CC716} = {A2F1F52A-9107-4BF8-8C3F-2F6670E7D0AD} {85557BA6-3D29-4C95-A646-2A972B1C2F25} = {2A77AF5C-138A-4EBB-9A13-9205DCD67928} {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} + {05C500AF-9CF6-C2E7-2782-95271975A5DE} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {384A3888-751F-4D75-9AE5-587330582D89} diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 4b0babb1..f5586d6b 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -27,25 +27,23 @@ public static async Task Main(string[] args) ) }; - try - { - // Create the client with authorization-enabled transport - var transport = new SseClientTransport(transportOptions); - var client = await McpClientFactory.CreateAsync(transport); + // Create the client with authorization-enabled transport + var transport = new SseClientTransport(transportOptions); + var client = await McpClientFactory.CreateAsync(transport); - // Print the list of tools available from the server. - foreach (var tool in await client.ListToolsAsync()) - { - Console.WriteLine($"{tool.Name} ({tool.Description})"); - } + // Print the list of tools available from the server. + foreach (var tool in await client.ListToolsAsync()) + { + Console.WriteLine($"{tool.Name} ({tool.Description})"); + } - // Execute a tool (this would normally be driven by LLM tool invocations). - var result = await client.CallToolAsync( - "echo", - new Dictionary() { ["message"] = "Hello MCP!" }, - cancellationToken: CancellationToken.None); + // Execute a tool (this would normally be driven by LLM tool invocations). + var result = await client.CallToolAsync( + "echo", + new Dictionary() { ["message"] = "Hello MCP!" }, + cancellationToken: CancellationToken.None); - // echo always returns one and only one text content object - Console.WriteLine(result.Content.First(c => c.Type == "text").Text); - } + // echo always returns one and only one text content object + Console.WriteLine(result.Content.First(c => c.Type == "text").Text); + } } \ No newline at end of file diff --git a/samples/AuthorizationServerExample/AuthorizationServerExample.csproj b/samples/AuthorizationServerExample/AuthorizationServerExample.csproj new file mode 100644 index 00000000..fad460dc --- /dev/null +++ b/samples/AuthorizationServerExample/AuthorizationServerExample.csproj @@ -0,0 +1,15 @@ + + + + Exe + net8.0 + enable + enable + + + + + + + + \ No newline at end of file diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs new file mode 100644 index 00000000..f9ef6d60 --- /dev/null +++ b/samples/AuthorizationServerExample/Program.cs @@ -0,0 +1,139 @@ +using Microsoft.AspNetCore.Builder; +using ModelContextProtocol; +using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server.Auth; +using System.Text.Json; + +namespace AuthorizationServerExample; + +/// +/// Example demonstrating how to implement authorization in an MCP server. +/// +public class Program +{ + public static async Task Main(string[] args) + { + Console.WriteLine("=== MCP Server with Authorization Support ==="); + Console.WriteLine("This example demonstrates how to implement OAuth authorization in an MCP server."); + Console.WriteLine(); + + var builder = WebApplication.CreateBuilder(args); + + // 1. Define the Protected Resource Metadata for the server + // This is the information that will be provided to clients when they need to authenticate + var prm = new ProtectedResourceMetadata + { + Resource = "https://localhost:7071", // The resource identifier (typically your server's base URL) + AuthorizationServers = ["https://auth.example.com"], // Auth servers that can issue tokens for this resource + BearerMethodsSupported = ["header"], // We support the Authorization header + ScopesSupported = ["mcp.tools", "mcp.prompts", "mcp.resources"], // Scopes supported by this resource + ResourceDocumentation = "https://example.com/docs/mcp-server-auth" // Optional documentation URL + }; + + // 2. Define a token validator function + // This function receives the token from the Authorization header and should validate it + // In a real application, this would verify the token with your identity provider + async Task ValidateToken(string token) + { + // For demo purposes, we'll accept any token that starts with "valid_" + // In production, you would validate the token with your identity provider + var isValid = token.StartsWith("valid_", StringComparison.OrdinalIgnoreCase); + Console.WriteLine($"Token validation result: {(isValid ? "Valid" : "Invalid")}"); + return isValid; + } + + // 3. Create an authorization provider with the PRM and token validator + var authProvider = new SimpleServerAuthorizationProvider(prm, ValidateToken); + + // 4. Configure the MCP server with authorization + builder.Services.AddMcpServer(options => + { + options.ServerInstructions = "This is an MCP server with OAuth authorization enabled."; + + // Configure regular server capabilities like tools, prompts, resources + options.Capabilities = new() + { + Tools = new() + { + // Simple Echo tool + + CallToolHandler = (request, cancellationToken) => + { + if (request.Params?.Name == "echo") + { + if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) + { + throw new McpException("Missing required argument 'message'"); + } + + return new ValueTask(new CallToolResponse() + { + Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] + }); + } + + // Protected tool that requires authorization + if (request.Params?.Name == "protected-data") + { + // This tool will only be accessible to authenticated clients + return new ValueTask(new CallToolResponse() + { + Content = [new Content() { Text = "This is protected data that only authorized clients can access" }] + }); + } + + throw new McpException($"Unknown tool: '{request.Params?.Name}'"); + }, + + ListToolsHandler = async (_, _) => new() + { + Tools = + [ + new() + { + Name = "echo", + Description = "Echoes back the message you send" + }, + new() + { + Name = "protected-data", + Description = "Returns protected data that requires authorization" + } + ] + } + } + }; + }) + .WithAuthorization(authProvider) // Enable authorization with our provider + .WithHttpTransport(); // Configure HTTP transport + + var app = builder.Build(); + + // 5. Enable authorization middleware (this must be before MapMcp) + // This middleware does several things: + // - Serves the PRM document at /.well-known/oauth-protected-resource + // - Checks Authorization header on requests + // - Returns 401 + WWW-Authenticate when authorization is missing or invalid + app.UseMcpAuthorization(); + + // 6. Map MCP endpoints + app.MapMcp(); + + // Configure the server URL + app.Urls.Add("https://localhost:7071"); + + Console.WriteLine("Starting MCP server with authorization at https://localhost:7071"); + Console.WriteLine("PRM Document URL: https://localhost:7071/.well-known/oauth-protected-resource"); + Console.WriteLine(); + Console.WriteLine("To test the server:"); + Console.WriteLine("1. Use an MCP client that supports authorization"); + Console.WriteLine("2. When prompted for authorization, enter 'valid_token' to gain access"); + Console.WriteLine("3. Any other token value will be rejected with a 401 Unauthorized"); + Console.WriteLine(); + Console.WriteLine("Press Ctrl+C to stop the server"); + + await app.RunAsync(); + } +} \ No newline at end of file diff --git a/samples/AuthorizationServerExample/Properties/launchSettings.json b/samples/AuthorizationServerExample/Properties/launchSettings.json new file mode 100644 index 00000000..35989814 --- /dev/null +++ b/samples/AuthorizationServerExample/Properties/launchSettings.json @@ -0,0 +1,12 @@ +{ + "profiles": { + "AuthorizationServerExample": { + "commandName": "Project", + "launchBrowser": true, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + }, + "applicationUrl": "https://localhost:50481;http://localhost:50482" + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationExtensions.cs new file mode 100644 index 00000000..b4f0a136 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpAuthorizationExtensions.cs @@ -0,0 +1,21 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Extension methods for using MCP authorization in ASP.NET Core applications. +/// +public static class McpAuthorizationExtensions +{ + /// + /// Adds MCP authorization middleware to the specified , which enables + /// OAuth 2.0 authorization for MCP servers. + /// + /// The to add the middleware to. + /// A reference to this instance after the operation has completed. + public static IApplicationBuilder UseMcpAuthorization(this IApplicationBuilder builder) + { + return builder.UseMiddleware(); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationMiddleware.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationMiddleware.cs new file mode 100644 index 00000000..f128c2fc --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpAuthorizationMiddleware.cs @@ -0,0 +1,113 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Middleware that handles authorization for MCP servers. +/// +internal class McpAuthorizationMiddleware +{ + private readonly RequestDelegate _next; + private readonly ILogger _logger; + + /// + /// Initializes a new instance of the class. + /// + /// The next middleware in the pipeline. + /// The logger factory. + public McpAuthorizationMiddleware(RequestDelegate next, ILogger logger) + { + _next = next ?? throw new ArgumentNullException(nameof(next)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } + + /// + /// Processes a request. + /// + /// The HTTP context. + /// The MCP server options. + /// A representing the asynchronous operation. + public async Task InvokeAsync(HttpContext context, IOptions serverOptions) + { + // Check if authorization is configured + var authCapability = serverOptions.Value.Capabilities?.Authorization; + var authProvider = authCapability?.AuthorizationProvider; + + if (authProvider == null) + { + // Authorization is not configured, proceed to the next middleware + await _next(context); + return; + } + + // Handle the PRM document endpoint + if (context.Request.Path.StartsWithSegments("/.well-known/oauth-protected-resource")) + { + _logger.LogDebug("Serving Protected Resource Metadata document"); + context.Response.ContentType = "application/json"; + await JsonSerializer.SerializeAsync( + context.Response.Body, + authProvider.GetProtectedResourceMetadata(), + McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); + return; + } + + // Serve SSE and message endpoints with authorization + if (context.Request.Path.StartsWithSegments("/sse") || + (context.Request.Path.Value?.EndsWith("/message") == true)) + { + // Check if the Authorization header is present + if (!context.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader)) + { + // No Authorization header present, return 401 Unauthorized + var prm = authProvider.GetProtectedResourceMetadata(); + var prmUrl = GetPrmUrl(context, prm.Resource); + + _logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header"); + context.Response.StatusCode = StatusCodes.Status401Unauthorized; + context.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return; + } + + // Validate the token - ensuring authHeader is a non-null string + string authHeaderValue = authHeader.ToString(); + bool isValid = await authProvider.ValidateTokenAsync(authHeaderValue); + if (!isValid) + { + // Invalid token, return 401 Unauthorized + var prm = authProvider.GetProtectedResourceMetadata(); + var prmUrl = GetPrmUrl(context, prm.Resource); + + _logger.LogDebug("Invalid authorization token, returning 401 Unauthorized"); + context.Response.StatusCode = StatusCodes.Status401Unauthorized; + context.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return; + } + } + + // Token is valid or endpoint doesn't require authentication, proceed to the next middleware + await _next(context); + } + + private static string GetPrmUrl(HttpContext context, string resourceUri) + { + // Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL + if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _)) + { + return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource"; + } + + // Build the URL from the current request + var request = context.Request; + var scheme = request.Scheme; + var host = request.Host.Value; + return $"{scheme}://{host}/.well-known/oauth-protected-resource"; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs new file mode 100644 index 00000000..16ea0660 --- /dev/null +++ b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs @@ -0,0 +1,38 @@ +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods for configuring authorization in MCP servers. +/// +public static class McpServerAuthorizationExtensions +{ + /// + /// Adds authorization support to the MCP server. + /// + /// The to configure. + /// The authorization provider that will validate tokens and provide metadata. + /// The so that additional calls can be chained. + /// or is . + public static IMcpServerBuilder WithAuthorization( + this IMcpServerBuilder builder, + IMcpServerAuthorizationProvider authorizationProvider) + { + Throw.IfNull(builder); + Throw.IfNull(authorizationProvider); + + builder.Services.Configure(options => + { + options.Capabilities ??= new ServerCapabilities(); + options.Capabilities.Authorization = new AuthorizationCapability + { + AuthorizationProvider = authorizationProvider + }; + }); + + return builder; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/IMcpServerAuthorizationProvider.cs b/src/ModelContextProtocol/Protocol/Auth/IMcpServerAuthorizationProvider.cs new file mode 100644 index 00000000..53f4ca2e --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/IMcpServerAuthorizationProvider.cs @@ -0,0 +1,26 @@ +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Defines the interface for MCP server authorization providers. +/// +/// +/// This interface is implemented by authorization providers that enable MCP servers to validate tokens +/// and control access to protected resources. +/// +public interface IMcpServerAuthorizationProvider +{ + /// + /// Gets the Protected Resource Metadata (PRM) for the server. + /// + /// The protected resource metadata. + ProtectedResourceMetadata GetProtectedResourceMetadata(); + + /// + /// Validates the provided authorization token. + /// + /// The authorization header value. + /// A representing the asynchronous validation operation. The task result contains if the token is valid; otherwise, . + Task ValidateTokenAsync(string authorizationHeader); +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs new file mode 100644 index 00000000..7194b1b0 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs @@ -0,0 +1,58 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the Protected Resource Metadata (PRM) document for an OAuth 2.0 protected resource. +/// +/// +/// The PRM document describes the properties and requirements of a protected resource, including +/// the authorization servers that can be used to obtain access tokens and the scopes that are supported. +/// This document is served at the standard path "/.well-known/oauth-protected-resource" by MCP servers +/// that have authorization enabled. +/// +public class ProtectedResourceMetadata +{ + /// + /// Gets or sets the resource identifier URI. + /// + [JsonPropertyName("resource")] + public required string Resource { get; set; } + + /// + /// Gets or sets the authorization servers that can be used for authentication. + /// + [JsonPropertyName("authorization_servers")] + public required string[] AuthorizationServers { get; set; } + + /// + /// Gets or sets the bearer token methods supported by the resource. + /// + [JsonPropertyName("bearer_methods_supported")] + public string[]? BearerMethodsSupported { get; set; } = ["header"]; + + /// + /// Gets or sets the scopes supported by the resource. + /// + [JsonPropertyName("scopes_supported")] + public string[]? ScopesSupported { get; set; } + + /// + /// Gets or sets the URL to the resource documentation. + /// + [JsonPropertyName("resource_documentation")] + public string? ResourceDocumentation { get; set; } + + /// + /// Converts this to the internal type. + /// + /// A instance with the same values as this instance. + internal ResourceMetadata ToResourceMetadata() => new() + { + Resource = Resource, + AuthorizationServers = AuthorizationServers, + BearerMethodsSupported = BearerMethodsSupported, + ScopesSupported = ScopesSupported, + ResourceDocumentation = ResourceDocumentation + }; +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs b/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs new file mode 100644 index 00000000..b2361678 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs @@ -0,0 +1,19 @@ +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.Protocol.Types; + +/// +/// Defines the capabilities of a server for supporting OAuth 2.0 authorization. +/// +/// +/// This capability is advertised by servers that support OAuth 2.0 authorization flows +/// and require clients to authenticate using bearer tokens. +/// +public class AuthorizationCapability +{ + /// + /// Gets or sets the authorization provider that handles token validation and provides + /// metadata about the protected resource. + /// + public IMcpServerAuthorizationProvider? AuthorizationProvider { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs index 6406ea4d..8e524845 100644 --- a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs @@ -35,6 +35,12 @@ public class ServerCapabilities [JsonPropertyName("experimental")] public Dictionary? Experimental { get; set; } + /// + /// Gets or sets a server's authorization capability, supporting OAuth 2.0 authorization flows. + /// + [JsonPropertyName("authorization")] + public AuthorizationCapability? Authorization { get; set; } + /// /// Gets or sets a server's logging capability, supporting sending log messages to the client. /// diff --git a/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs b/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs new file mode 100644 index 00000000..925da2cd --- /dev/null +++ b/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs @@ -0,0 +1,54 @@ +using ModelContextProtocol.Protocol.Auth; +using System.Security.Claims; + +namespace ModelContextProtocol.Server.Auth; + +/// +/// A simple implementation of that validates bearer tokens. +/// +/// +/// This implementation is intended as a starting point for server developers. In production environments, +/// it should be extended or replaced with a more robust implementation that integrates with your +/// authentication system (e.g., OAuth 2.0 server, identity provider, etc.) +/// +public class SimpleServerAuthorizationProvider : IMcpServerAuthorizationProvider +{ + private readonly ProtectedResourceMetadata _resourceMetadata; + private readonly Func> _tokenValidator; + + /// + /// Initializes a new instance of the class + /// with the specified resource metadata and token validator. + /// + /// The protected resource metadata. + /// A function that validates access tokens. If not provided, a function that always returns true will be used. + public SimpleServerAuthorizationProvider( + ProtectedResourceMetadata resourceMetadata, + Func>? tokenValidator = null) + { + _resourceMetadata = resourceMetadata ?? throw new ArgumentNullException(nameof(resourceMetadata)); + _tokenValidator = tokenValidator ?? (_ => Task.FromResult(true)); + } + + /// + public ProtectedResourceMetadata GetProtectedResourceMetadata() => _resourceMetadata; + + /// + public async Task ValidateTokenAsync(string authorizationHeader) + { + // Extract the token from the Authorization header + if (string.IsNullOrEmpty(authorizationHeader) || !authorizationHeader.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + var token = authorizationHeader["Bearer ".Length..].Trim(); + if (string.IsNullOrEmpty(token)) + { + return false; + } + + // Validate the token + return await _tokenValidator(token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index ae0e7afc..aeee62dd 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -66,6 +66,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? SetResourcesHandler(options); SetSetLoggingLevelHandler(options); SetCompletionHandler(options); + SetAuthorizationHandler(); SetPingHandler(); // Register any notification handlers that were provided. @@ -327,6 +328,7 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals ServerCapabilities = new() { Experimental = options.Capabilities?.Experimental, + Authorization = options.Capabilities?.Authorization, Logging = options.Capabilities?.Logging, Tools = options.Capabilities?.Tools, Resources = options.Capabilities?.Resources, @@ -425,6 +427,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) ServerCapabilities = new() { Experimental = options.Capabilities?.Experimental, + Authorization = options.Capabilities?.Authorization, Logging = options.Capabilities?.Logging, Prompts = options.Capabilities?.Prompts, Resources = options.Capabilities?.Resources, @@ -503,6 +506,13 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) McpJsonUtilities.JsonContext.Default.EmptyResult); } + private void SetAuthorizationHandler() + { + // The authorization capability is handled via middleware in ASP.NET Core, + // so we don't need to set up any special handlers here. + // We just make sure to include the capability in the ServerCapabilities. + } + private ValueTask InvokeHandlerAsync( Func, CancellationToken, ValueTask> handler, TParams? args, diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index a8b2b996..169b27e3 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -125,6 +125,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element) // Authorization-related types [JsonSerializable(typeof(Protocol.Auth.ResourceMetadata))] + [JsonSerializable(typeof(Protocol.Auth.ProtectedResourceMetadata))] [JsonSerializable(typeof(Protocol.Auth.AuthorizationServerMetadata))] [JsonSerializable(typeof(Protocol.Auth.ClientMetadata))] [JsonSerializable(typeof(Protocol.Auth.ClientRegistrationResponse))] From 3e9462c652d977174668f90035333bacdb28a0c7 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 16:26:29 -0700 Subject: [PATCH 06/28] HTTP for local testing --- samples/AuthorizationServerExample/Program.cs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs index f9ef6d60..61cfc00a 100644 --- a/samples/AuthorizationServerExample/Program.cs +++ b/samples/AuthorizationServerExample/Program.cs @@ -1,10 +1,8 @@ -using Microsoft.AspNetCore.Builder; using ModelContextProtocol; using ModelContextProtocol.AspNetCore; using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server.Auth; -using System.Text.Json; namespace AuthorizationServerExample; @@ -25,7 +23,7 @@ public static async Task Main(string[] args) // This is the information that will be provided to clients when they need to authenticate var prm = new ProtectedResourceMetadata { - Resource = "https://localhost:7071", // The resource identifier (typically your server's base URL) + Resource = "http://localhost:7071", // Changed from HTTPS to HTTP for local development AuthorizationServers = ["https://auth.example.com"], // Auth servers that can issue tokens for this resource BearerMethodsSupported = ["header"], // We support the Authorization header ScopesSupported = ["mcp.tools", "mcp.prompts", "mcp.resources"], // Scopes supported by this resource @@ -122,10 +120,11 @@ async Task ValidateToken(string token) app.MapMcp(); // Configure the server URL - app.Urls.Add("https://localhost:7071"); + app.Urls.Add("http://localhost:7071"); + + Console.WriteLine("Starting MCP server with authorization at http://localhost:7071"); + Console.WriteLine("PRM Document URL: http://localhost:7071/.well-known/oauth-protected-resource"); - Console.WriteLine("Starting MCP server with authorization at https://localhost:7071"); - Console.WriteLine("PRM Document URL: https://localhost:7071/.well-known/oauth-protected-resource"); Console.WriteLine(); Console.WriteLine("To test the server:"); Console.WriteLine("1. Use an MCP client that supports authorization"); From ecc40ab879802fef686787ec7d1d122ef8708b8d Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 16:57:05 -0700 Subject: [PATCH 07/28] Tinkering with test logic --- samples/AuthorizationExample/Program.cs | 3 +-- samples/AuthorizationServerExample/Program.cs | 2 +- .../Protocol/Auth/AuthorizationContext.cs | 10 +++++++++- .../Protocol/Auth/DefaultAuthorizationHandler.cs | 8 ++++---- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index f5586d6b..1278b880 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -12,7 +12,7 @@ public class Program public static async Task Main(string[] args) { // Define the MCP server endpoint that requires OAuth authentication - var serverEndpoint = new Uri("https://example.com/mcp"); + var serverEndpoint = new Uri("http://localhost:7071/sse"); // Set up the SSE transport with authorization support var transportOptions = new SseClientTransportOptions @@ -21,7 +21,6 @@ public static async Task Main(string[] args) AuthorizeCallback = SseClientTransport.CreateLocalServerAuthorizeCallback( openBrowser: async (url) => { - // Open the URL in the user's default browser Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); } ) diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs index 61cfc00a..97dfe309 100644 --- a/samples/AuthorizationServerExample/Program.cs +++ b/samples/AuthorizationServerExample/Program.cs @@ -24,7 +24,7 @@ public static async Task Main(string[] args) var prm = new ProtectedResourceMetadata { Resource = "http://localhost:7071", // Changed from HTTPS to HTTP for local development - AuthorizationServers = ["https://auth.example.com"], // Auth servers that can issue tokens for this resource + AuthorizationServers = ["https://login.microsoftonline.com/a2213e1c-e51e-4304-9a0d-effe57f31655/v2.0"], // Let's use a dummy Entra ID tenant here BearerMethodsSupported = ["header"], // We support the Authorization header ScopesSupported = ["mcp.tools", "mcp.prompts", "mcp.resources"], // Scopes supported by this resource ResourceDocumentation = "https://example.com/docs/mcp-server-auth" // Optional documentation URL diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs index e65dd4ff..c210fd9e 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs @@ -85,7 +85,15 @@ public bool ValidateResourceUrl(string resourceUrl) return false; } - // Resource URL must match exactly + // Compare the host part (FQDN) rather than the full URL + if (Uri.TryCreate(resourceUrl, UriKind.Absolute, out Uri? resourceUri) && + Uri.TryCreate(ResourceMetadata.Resource, UriKind.Absolute, out Uri? metadataUri)) + { + // Compare only the host (domain name) + return string.Equals(resourceUri.Host, metadataUri.Host, StringComparison.OrdinalIgnoreCase); + } + + // If we can't parse both URLs, fall back to exact string comparison return string.Equals(resourceUrl, ResourceMetadata.Resource, StringComparison.OrdinalIgnoreCase); } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs index 20f32922..711c37bb 100644 --- a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -75,9 +75,11 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp throw exception; } + // Store the resource metadata in the context before validating the resource URL + authContext.Value.ResourceMetadata = resourceMetadata; + // Validate that the resource matches the server FQDN - if (!authContext.Value.ValidateResourceUrl(serverUri.ToString()) && - !string.Equals(resourceMetadata.Resource, serverUri.ToString(), StringComparison.OrdinalIgnoreCase)) + if (!authContext.Value.ValidateResourceUrl(serverUri.ToString())) { _logger.LogWarning("Resource URL mismatch: expected {Expected}, got {Actual}", serverUri, resourceMetadata.Resource); @@ -87,8 +89,6 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp throw exception; } - authContext.Value.ResourceMetadata = resourceMetadata; - // Get the first authorization server from the metadata if (resourceMetadata.AuthorizationServers == null || resourceMetadata.AuthorizationServers.Length == 0) { From 4c7a578398e4cd1606894f8ee57c4ec5852693b3 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 19:57:27 -0700 Subject: [PATCH 08/28] Iterating on the changes --- samples/AuthorizationExample/Program.cs | 84 ++++++++++---- .../Auth/DefaultAuthorizationHandler.cs | 73 +++++++++--- .../Protocol/Auth/McpAuthorizationOptions.cs | 81 +++++++++++++ .../Transport/SseClientSessionTransport.cs | 12 +- .../Protocol/Transport/SseClientTransport.cs | 109 +++++++++++++----- .../Transport/SseClientTransportOptions.cs | 37 +++--- 6 files changed, 312 insertions(+), 84 deletions(-) create mode 100644 src/ModelContextProtocol/Protocol/Auth/McpAuthorizationOptions.cs diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 1278b880..56ad3c06 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -1,5 +1,6 @@ using System.Diagnostics; using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Transport; namespace AuthorizationExample; @@ -14,35 +15,76 @@ public static async Task Main(string[] args) // Define the MCP server endpoint that requires OAuth authentication var serverEndpoint = new Uri("http://localhost:7071/sse"); + // Configuration values for OAuth redirect + string hostname = "localhost"; + int port = 8888; + string callbackPath = "/oauth/callback"; + // Set up the SSE transport with authorization support var transportOptions = new SseClientTransportOptions { Endpoint = serverEndpoint, - AuthorizeCallback = SseClientTransport.CreateLocalServerAuthorizeCallback( - openBrowser: async (url) => - { - Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); - } - ) + AuthorizationOptions = new McpAuthorizationOptions + { + // Pre-registered client credentials (if applicable) + ClientId = "my-registered-client-id", + ClientSecret = "optional-client-secret", + + // Specify the exact same redirect URIs that are registered with the OAuth server + RedirectUris = new[] + { + $"http://{hostname}:{port}{callbackPath}" + }, + + // Configure the authorize callback with the same hostname, port, and path + AuthorizeCallback = SseClientTransport.CreateHttpListenerAuthorizeCallback( + openBrowser: async (url) => + { + Console.WriteLine($"Opening browser to authorize at: {url}"); + Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); + }, + hostname: hostname, + listenPort: port, + redirectPath: callbackPath, + successHtml: "

Authorization Successful

You have successfully authorized the application. You can close this window and return to the app.

" + ) + } }; - // Create the client with authorization-enabled transport - var transport = new SseClientTransport(transportOptions); - var client = await McpClientFactory.CreateAsync(transport); - - // Print the list of tools available from the server. - foreach (var tool in await client.ListToolsAsync()) + Console.WriteLine("Connecting to MCP server..."); + + try { - Console.WriteLine($"{tool.Name} ({tool.Description})"); - } + // Create the client with authorization-enabled transport + var transport = new SseClientTransport(transportOptions); + var client = await McpClientFactory.CreateAsync(transport); - // Execute a tool (this would normally be driven by LLM tool invocations). - var result = await client.CallToolAsync( - "echo", - new Dictionary() { ["message"] = "Hello MCP!" }, - cancellationToken: CancellationToken.None); + Console.WriteLine("Successfully connected and authorized!"); + + // Print the list of tools available from the server. + Console.WriteLine("\nAvailable tools:"); + foreach (var tool in await client.ListToolsAsync()) + { + Console.WriteLine($" - {tool.Name}: {tool.Description}"); + } - // echo always returns one and only one text content object - Console.WriteLine(result.Content.First(c => c.Type == "text").Text); + // Execute a tool (this would normally be driven by LLM tool invocations). + Console.WriteLine("\nCalling 'echo' tool..."); + var result = await client.CallToolAsync( + "echo", + new Dictionary() { ["message"] = "Hello MCP!" }, + cancellationToken: CancellationToken.None); + + // echo always returns one and only one text content object + Console.WriteLine($"Tool response: {result.Content.First(c => c.Type == "text").Text}"); + } + catch (Exception ex) + { + Console.WriteLine($"Error: {ex.Message}"); + if (ex.InnerException != null) + { + Console.WriteLine($"Inner Error: {ex.InnerException.Message}"); + } + } } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs index 711c37bb..ce240e8c 100644 --- a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -15,18 +15,40 @@ internal class DefaultAuthorizationHandler : IAuthorizationHandler private readonly ILogger _logger; private readonly SynchronizedValue _authContext = new(new AuthorizationContext()); private readonly Func>? _authorizeCallback; + private readonly string? _clientId; + private readonly string? _clientSecret; + private readonly ICollection? _redirectUris; + private readonly ICollection? _scopes; /// /// Initializes a new instance of the class. /// /// The logger factory. - /// A callback function that handles the authorization code flow. - public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, Func>? authorizeCallback = null) + /// The authorization options. + public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, McpAuthorizationOptions? options = null) { _logger = loggerFactory != null ? loggerFactory.CreateLogger() : NullLogger.Instance; - _authorizeCallback = authorizeCallback; + + if (options != null) + { + _authorizeCallback = options.AuthorizeCallback; + _clientId = options.ClientId; + _clientSecret = options.ClientSecret; + _redirectUris = options.RedirectUris; + _scopes = options.Scopes; + } + } + + /// + /// Initializes a new instance of the class. + /// + /// The logger factory. + /// A callback function that handles the authorization code flow. + public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, Func>? authorizeCallback = null) + : this(loggerFactory, new McpAuthorizationOptions { AuthorizeCallback = authorizeCallback }) + { } /// @@ -110,16 +132,31 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp _logger.LogDebug("Successfully retrieved authorization server metadata"); // Create client metadata + string[] redirectUris = _redirectUris?.ToArray() ?? new[] { "http://localhost:8888/callback" }; var clientMetadata = new ClientMetadata { - RedirectUris = new[] { "http://localhost:8888/callback" }, // Default redirect URI + RedirectUris = redirectUris, ClientName = "MCP C# SDK Client", - Scope = string.Join(" ", resourceMetadata.ScopesSupported ?? Array.Empty()) + Scope = string.Join(" ", _scopes ?? resourceMetadata.ScopesSupported ?? Array.Empty()) }; - - // Register client if the server supports it - if (authServerMetadata.RegistrationEndpoint != null) + + // Register client if needed, or use pre-configured client ID + if (!string.IsNullOrEmpty(_clientId)) + { + _logger.LogDebug("Using pre-configured client ID: {ClientId}", _clientId); + + // Create a client registration response to store in the context + var clientRegistration = new ClientRegistrationResponse + { + ClientId = _clientId!, // Using null-forgiving operator since we've already checked it's not null + ClientSecret = _clientSecret, + }; + + authContext.Value.ClientRegistration = clientRegistration; + } + else if (authServerMetadata.RegistrationEndpoint != null) { + // Register client dynamically _logger.LogDebug("Registering client with authorization server"); var clientRegistration = await AuthorizationService.RegisterClientAsync(authServerMetadata, clientMetadata); authContext.Value.ClientRegistration = clientRegistration; @@ -127,9 +164,11 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp } else { - _logger.LogWarning("Authorization server does not support dynamic client registration"); + _logger.LogWarning("Authorization server does not support dynamic client registration and no client ID was provided"); - var exception = new McpAuthorizationException("Authorization server does not support dynamic client registration"); + var exception = new McpAuthorizationException( + "Authorization server does not support dynamic client registration and no client ID was provided. " + + "Use McpAuthorizationOptions.ClientId to provide a pre-registered client ID."); exception.ResourceUri = resourceMetadata.Resource; exception.AuthorizationServerUri = authServerUrl; throw exception; @@ -142,7 +181,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp var exception = new McpAuthorizationException( "Authentication is required but no authorization callback was provided. " + - "Use SseClientTransportOptions.AuthorizeCallback to provide a callback function."); + "Use McpAuthorizationOptions.AuthorizeCallback to provide a callback function."); exception.ResourceUri = resourceMetadata.Resource; exception.AuthorizationServerUri = authServerUrl; throw exception; @@ -155,18 +194,18 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp // Initiate authorization code flow _logger.LogDebug("Initiating authorization code flow"); - // Get the registered client ID - var clientId = authContext.Value.ClientRegistration!.ClientId; - // Get the authorization URL that the user needs to visit var authUrl = AuthorizationService.CreateAuthorizationUrl( authServerMetadata, - clientId, + authContext.Value.ClientRegistration.ClientId, clientMetadata.RedirectUris[0], codeChallenge, - resourceMetadata.ScopesSupported); + _scopes?.ToArray() ?? resourceMetadata.ScopesSupported); _logger.LogDebug("Authorization URL: {AuthUrl}", authUrl); + + // Set the authorization URL in the client metadata + clientMetadata.ClientUri = authUrl; // Let the callback handle the user authorization var (redirectUri, code) = await _authorizeCallback(clientMetadata); @@ -176,7 +215,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp _logger.LogDebug("Exchanging authorization code for tokens"); var tokenResponse = await AuthorizationService.ExchangeCodeForTokensAsync( authServerMetadata, - clientId, + authContext.Value.ClientRegistration.ClientId, authContext.Value.ClientRegistration.ClientSecret, redirectUri, code, diff --git a/src/ModelContextProtocol/Protocol/Auth/McpAuthorizationOptions.cs b/src/ModelContextProtocol/Protocol/Auth/McpAuthorizationOptions.cs new file mode 100644 index 00000000..85af0a94 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/McpAuthorizationOptions.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Provides authorization options for MCP clients. +/// +public class McpAuthorizationOptions +{ + /// + /// Gets or sets a delegate that handles the OAuth 2.0 authorization code flow. + /// + /// + /// + /// This delegate is called when the server requires OAuth 2.0 authorization. It receives the client metadata + /// and should return the redirect URI and authorization code received from the authorization server. + /// + /// + /// If not provided, the client will not be able to authenticate with servers that require OAuth authentication. + /// + /// + public Func>? AuthorizeCallback { get; init; } + + /// + /// Gets or sets the client ID to use for OAuth authorization. + /// + /// + /// + /// If specified, this client ID will be used during the OAuth flow instead of performing dynamic client registration. + /// This is useful when connecting to servers that have pre-registered clients. + /// + /// + public string? ClientId { get; init; } + + /// + /// Gets or sets the client secret associated with the client ID. + /// + /// + /// This is only required if the client was registered as a confidential client with the authorization server. + /// Public clients don't require a client secret. + /// + public string? ClientSecret { get; init; } + + /// + /// Gets or sets the redirect URIs that can be used during the OAuth authorization flow. + /// + /// + /// + /// These URIs must match the redirect URIs registered with the authorization server for the client. + /// + /// + /// If not specified and is set, a default value of + /// "http://localhost:8888/callback" will be used. + /// + /// + public ICollection? RedirectUris { get; init; } + + /// + /// Gets or sets the scopes to request during OAuth authorization. + /// + /// + /// + /// If not specified, the scopes will be determined from the server's resource metadata. + /// + /// + public ICollection? Scopes { get; init; } + + /// + /// Gets or sets a custom authorization handler. + /// + /// + /// + /// If specified, this handler will be used to manage authorization with the server. + /// + /// + /// If not provided, a default handler will be created using the other options. + /// + /// + public IAuthorizationHandler? AuthorizationHandler { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index f77b49fa..2f5ad8f3 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -50,8 +50,16 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Htt _connectionEstablished = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); // Initialize the authorization handler - _authorizationHandler = transportOptions.AuthorizationHandler ?? - new DefaultAuthorizationHandler(loggerFactory, transportOptions.AuthorizeCallback); + if (transportOptions.AuthorizationOptions?.AuthorizationHandler != null) + { + // Use explicitly provided handler + _authorizationHandler = transportOptions.AuthorizationOptions.AuthorizationHandler; + } + else + { + // Create default handler with auth options + _authorizationHandler = new DefaultAuthorizationHandler(loggerFactory, transportOptions.AuthorizationOptions); + } } /// diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index e96dfd75..c4d6c9de 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -16,8 +16,6 @@ namespace ModelContextProtocol.Protocol.Transport; /// /// /// The SSE transport can handle OAuth 2.0 authorization flows when connecting to servers that require authentication. -/// You can provide an in the transport options to handle the user authentication part -/// of the OAuth flow. /// /// public sealed class SseClientTransport : IClientTransport, IAsyncDisposable @@ -63,43 +61,85 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient public string Name { get; } /// - /// Creates a delegate that can handle the OAuth 2.0 authorization code flow. + /// Creates a delegate that can handle the OAuth 2.0 authorization code flow using an HTTP listener. /// /// A function that opens a URL in the browser. - /// The local port to listen on for the redirect URI. - /// The path for the redirect URI. - /// A delegate that can be used for the property. + /// The hostname to listen on for the redirect URI. Default is "localhost". + /// The port to listen on for the redirect URI. Default is 8888. + /// The path for the redirect URI. Default is "/callback". + /// The HTML content to display on successful authorization. If null, a default message is shown. + /// The HTML template to display on failed authorization. If null, a default message is shown. Use {0} as a placeholder for the error message. + /// A delegate that can be used for the property. /// /// - /// This method creates a delegate that implements a complete local OAuth 2.0 authorization code flow. + /// This method creates a delegate that implements a complete OAuth 2.0 authorization code flow using an HTTP listener. /// When called, it will: /// /// /// Open the authorization URL in the browser - /// Start a local HTTP server to listen for the authorization code + /// Start an HTTP listener to receive the authorization code /// Return the redirect URI and authorization code when received /// /// - /// You can customize the port and path for the redirect URI. By default, it uses port 8888 and path "/callback". + /// You can customize the hostname, port, and path for the redirect URI to match your OAuth client configuration. /// /// - public static Func> CreateLocalServerAuthorizeCallback( + public static Func> CreateHttpListenerAuthorizeCallback( Func openBrowser, + string hostname = "localhost", int listenPort = 8888, - string redirectPath = "/callback") + string redirectPath = "/callback", + string? successHtml = null, + string? errorHtml = null) { return async (ClientMetadata clientMetadata) => { - var redirectUri = $"http://localhost:{listenPort}{redirectPath}"; + // Default redirect URI based on parameters + var defaultRedirectUri = $"http://{hostname}:{listenPort}{redirectPath}"; + + // First, try to find a matching redirect URI from the client metadata + var redirectUri = defaultRedirectUri; + var hostPrefix = $"http://{hostname}"; + + foreach (var uri in clientMetadata.RedirectUris) + { + if (uri.StartsWith(hostPrefix, StringComparison.OrdinalIgnoreCase)) + { + redirectUri = uri; + + // Parse the port and path from the selected URI to ensure we listen on the correct endpoint + if (Uri.TryCreate(uri, UriKind.Absolute, out var parsedUri)) + { + listenPort = parsedUri.IsDefaultPort ? 80 : parsedUri.Port; + redirectPath = parsedUri.AbsolutePath; + } + + break; + } + } // Use a TaskCompletionSource to wait for the authorization code var authCodeTcs = new TaskCompletionSource(); - // Start a local HTTP server to listen for the authorization code + // Start an HTTP listener to listen for the authorization code using var listener = new System.Net.HttpListener(); - listener.Prefixes.Add($"http://localhost:{listenPort}/"); + + // Ensure the URI format is correct for HttpListener + var listenerPrefix = $"http://{hostname}:{listenPort}/"; + if (redirectPath.Length > 1) + { + // If path is something like "/callback", we need to listen on all paths that start with it + var basePath = redirectPath.TrimEnd('/').TrimStart('/'); + listenerPrefix = $"http://{hostname}:{listenPort}/{basePath}/"; + } + + listener.Prefixes.Add(listenerPrefix); listener.Start(); + // Default HTML responses + var defaultSuccessHtml = "

Authorization Successful

You can now close this window and return to the application.

"; + var defaultErrorHtml = "

Authorization Failed

Error: {0}

"; + // Start listening for the callback asynchronously var listenerTask = Task.Run(async () => { @@ -115,20 +155,21 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient // Send a response to the browser var response = context.Response; response.ContentType = "text/html"; - var responseHtml = "

Authorization Successful

You can now close this window and return to the application.

"; + string responseHtml; if (!string.IsNullOrEmpty(error)) { - responseHtml = $"

Authorization Failed

Error: {error}

"; + responseHtml = string.Format(errorHtml ?? defaultErrorHtml, error); authCodeTcs.SetException(new McpException($"Authorization failed: {error}", McpErrorCode.InvalidRequest)); } else if (string.IsNullOrEmpty(code)) { - responseHtml = "

Authorization Failed

No authorization code received.

"; + responseHtml = string.Format(errorHtml ?? defaultErrorHtml, "No authorization code received"); authCodeTcs.SetException(new McpException("No authorization code received", McpErrorCode.InvalidRequest)); } else { + responseHtml = successHtml ?? defaultSuccessHtml; authCodeTcs.SetResult(code); } @@ -148,21 +189,14 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient }); // Open the authorization URL in the browser - foreach (var uri in clientMetadata.RedirectUris) - { - if (uri.StartsWith("http://localhost")) - { - redirectUri = uri; - break; - } - } - - // We need to actually open the browser with the authorization URL - // Find the auth URL from client metadata and pass to openBrowser if (clientMetadata.ClientUri != null) { await openBrowser(clientMetadata.ClientUri); } + else + { + authCodeTcs.SetException(new McpException("No authorization URL provided in client metadata", McpErrorCode.InvalidRequest)); + } // Wait for the authorization code var code = await authCodeTcs.Task; @@ -171,6 +205,25 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient }; } + /// + /// Creates a delegate that can handle the OAuth 2.0 authorization code flow using a local HTTP listener. + /// + /// A function that opens a URL in the browser. + /// The local port to listen on for the redirect URI. + /// The path for the redirect URI. + /// A delegate that can be used for the property. + /// + /// This is a convenience method that calls with "localhost" as the hostname. + /// + [Obsolete("Use CreateHttpListenerAuthorizeCallback instead. This method will be removed in a future version.")] + public static Func> CreateLocalServerAuthorizeCallback( + Func openBrowser, + int listenPort = 8888, + string redirectPath = "/callback") + { + return CreateHttpListenerAuthorizeCallback(openBrowser, "localhost", listenPort, redirectPath); + } + /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index 0ee1c0fd..28420bb0 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -59,29 +59,34 @@ public required Uri Endpoint public Dictionary? AdditionalHeaders { get; init; } /// - /// Gets or sets a delegate that handles the OAuth 2.0 authorization code flow. + /// Gets or sets the authorization options to use when connecting to the SSE server. /// /// /// - /// This delegate is called when the SSE server requires OAuth 2.0 authorization. It receives the client metadata - /// and should return the redirect URI and authorization code received from the authorization server. + /// These options configure the behavior of client-side authorization with the SSE server. /// /// - /// If not provided, the client will not be able to authenticate with servers that require OAuth authentication. - /// - /// - public Func>? AuthorizeCallback { get; init; } - - /// - /// Gets or sets a custom authorization handler. - /// - /// - /// - /// If specified, this handler will be used to manage authorization with the SSE server. + /// You can use this to specify a callback for handling the authorization code flow, + /// provide pre-registered client credentials, or configure other aspects of the OAuth flow. /// /// - /// If not provided, a default handler will be created using the . + /// Example: + /// + /// var transportOptions = new SseClientTransportOptions + /// { + /// Endpoint = new Uri("http://localhost:7071/sse"), + /// AuthorizationOptions = new McpAuthorizationOptions + /// { + /// ClientId = "my-client-id", + /// ClientSecret = "my-client-secret", + /// RedirectUris = new[] { "http://localhost:8888/callback" }, + /// AuthorizeCallback = SseClientTransport.CreateHttpListenerAuthorizeCallback( + /// openBrowser: url => Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }) + /// ) + /// } + /// }; + /// /// /// - public IAuthorizationHandler? AuthorizationHandler { get; init; } + public McpAuthorizationOptions? AuthorizationOptions { get; init; } } \ No newline at end of file From d339973330b7b208a35ca2812fb87a8c70cb7348 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 20:03:27 -0700 Subject: [PATCH 09/28] Testing client configuration --- samples/AuthorizationExample/Program.cs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 56ad3c06..1bc2cc56 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -27,9 +27,11 @@ public static async Task Main(string[] args) AuthorizationOptions = new McpAuthorizationOptions { // Pre-registered client credentials (if applicable) - ClientId = "my-registered-client-id", - ClientSecret = "optional-client-secret", - + ClientId = "04f79824-ab56-4511-a7cb-d7deaea92dc0", + + // Setting some pre-defined scopes the client requests. + Scopes = ["User.Read"], + // Specify the exact same redirect URIs that are registered with the OAuth server RedirectUris = new[] { From bdee0e346de2f3dc82bc22d93af63c998e559253 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 21:44:03 -0700 Subject: [PATCH 10/28] Update to make sure naming is consistent --- samples/AuthorizationExample/Program.cs | 2 +- ...tensions.cs => AuthorizationExtensions.cs} | 5 ++--- ...ddleware.cs => AuthorizationMiddleware.cs} | 9 ++++---- ...Exception.cs => AuthorizationException.cs} | 22 +++++++++---------- .../McpServerAuthorizationExtensions.cs | 2 +- ...tionOptions.cs => AuthorizationOptions.cs} | 2 +- .../Auth/DefaultAuthorizationHandler.cs | 18 +++++++-------- ...der.cs => IServerAuthorizationProvider.cs} | 2 +- .../Protocol/Transport/SseClientTransport.cs | 4 ++-- .../Transport/SseClientTransportOptions.cs | 2 +- .../Protocol/Types/AuthorizationCapability.cs | 2 +- .../Auth/SimpleServerAuthorizationProvider.cs | 4 ++-- 12 files changed, 36 insertions(+), 38 deletions(-) rename src/ModelContextProtocol.AspNetCore/{McpAuthorizationExtensions.cs => AuthorizationExtensions.cs} (80%) rename src/ModelContextProtocol.AspNetCore/{McpAuthorizationMiddleware.cs => AuthorizationMiddleware.cs} (93%) rename src/ModelContextProtocol/{McpAuthorizationException.cs => AuthorizationException.cs} (65%) rename src/ModelContextProtocol/Protocol/Auth/{McpAuthorizationOptions.cs => AuthorizationOptions.cs} (98%) rename src/ModelContextProtocol/Protocol/Auth/{IMcpServerAuthorizationProvider.cs => IServerAuthorizationProvider.cs} (95%) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 1bc2cc56..54da48c2 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -24,7 +24,7 @@ public static async Task Main(string[] args) var transportOptions = new SseClientTransportOptions { Endpoint = serverEndpoint, - AuthorizationOptions = new McpAuthorizationOptions + AuthorizationOptions = new AuthorizationOptions { // Pre-registered client credentials (if applicable) ClientId = "04f79824-ab56-4511-a7cb-d7deaea92dc0", diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationExtensions.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs similarity index 80% rename from src/ModelContextProtocol.AspNetCore/McpAuthorizationExtensions.cs rename to src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs index b4f0a136..16f93e98 100644 --- a/src/ModelContextProtocol.AspNetCore/McpAuthorizationExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs @@ -1,12 +1,11 @@ using Microsoft.AspNetCore.Builder; -using Microsoft.Extensions.DependencyInjection; namespace ModelContextProtocol.AspNetCore; /// /// Extension methods for using MCP authorization in ASP.NET Core applications. /// -public static class McpAuthorizationExtensions +public static class AuthorizationExtensions { /// /// Adds MCP authorization middleware to the specified , which enables @@ -16,6 +15,6 @@ public static class McpAuthorizationExtensions /// A reference to this instance after the operation has completed. public static IApplicationBuilder UseMcpAuthorization(this IApplicationBuilder builder) { - return builder.UseMiddleware(); + return builder.UseMiddleware(); } } \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationMiddleware.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs similarity index 93% rename from src/ModelContextProtocol.AspNetCore/McpAuthorizationMiddleware.cs rename to src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs index f128c2fc..ce7f9b69 100644 --- a/src/ModelContextProtocol.AspNetCore/McpAuthorizationMiddleware.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs @@ -2,7 +2,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using ModelContextProtocol.Protocol.Auth; -using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Utils.Json; using System.Text.Json; @@ -12,17 +11,17 @@ namespace ModelContextProtocol.AspNetCore; /// /// Middleware that handles authorization for MCP servers. /// -internal class McpAuthorizationMiddleware +internal class AuthorizationMiddleware { private readonly RequestDelegate _next; - private readonly ILogger _logger; + private readonly ILogger _logger; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The next middleware in the pipeline. /// The logger factory. - public McpAuthorizationMiddleware(RequestDelegate next, ILogger logger) + public AuthorizationMiddleware(RequestDelegate next, ILogger logger) { _next = next ?? throw new ArgumentNullException(nameof(next)); _logger = logger ?? throw new ArgumentNullException(nameof(logger)); diff --git a/src/ModelContextProtocol/McpAuthorizationException.cs b/src/ModelContextProtocol/AuthorizationException.cs similarity index 65% rename from src/ModelContextProtocol/McpAuthorizationException.cs rename to src/ModelContextProtocol/AuthorizationException.cs index 93eb4679..893b5ef8 100644 --- a/src/ModelContextProtocol/McpAuthorizationException.cs +++ b/src/ModelContextProtocol/AuthorizationException.cs @@ -7,52 +7,52 @@ namespace ModelContextProtocol; /// This exception is thrown when the client fails to authenticate with an MCP server that requires /// authentication, such as when the OAuth authorization flow fails or when the server rejects the provided credentials. /// -public class McpAuthorizationException : McpException +public class AuthorizationException : McpException { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// - public McpAuthorizationException() + public AuthorizationException() : base("Authorization failed", McpErrorCode.InvalidRequest) { } /// - /// Initializes a new instance of the class with a specified error message. + /// Initializes a new instance of the class with a specified error message. /// /// The message that describes the error. - public McpAuthorizationException(string message) + public AuthorizationException(string message) : base(message, McpErrorCode.InvalidRequest) { } /// - /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. /// /// The message that describes the error. /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. - public McpAuthorizationException(string message, Exception? innerException) + public AuthorizationException(string message, Exception? innerException) : base(message, innerException, McpErrorCode.InvalidRequest) { } /// - /// Initializes a new instance of the class with a specified error message and error code. + /// Initializes a new instance of the class with a specified error message and error code. /// /// The message that describes the error. /// The MCP error code. Should use one of the standard error codes. - public McpAuthorizationException(string message, McpErrorCode errorCode) + public AuthorizationException(string message, McpErrorCode errorCode) : base(message, errorCode) { } /// - /// Initializes a new instance of the class with a specified error message, inner exception, and error code. + /// Initializes a new instance of the class with a specified error message, inner exception, and error code. /// /// The message that describes the error. /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. /// The MCP error code. Should use one of the standard error codes. - public McpAuthorizationException(string message, Exception? innerException, McpErrorCode errorCode) + public AuthorizationException(string message, Exception? innerException, McpErrorCode errorCode) : base(message, innerException, errorCode) { } diff --git a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs index 16ea0660..ab29813e 100644 --- a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs @@ -19,7 +19,7 @@ public static class McpServerAuthorizationExtensions /// or is . public static IMcpServerBuilder WithAuthorization( this IMcpServerBuilder builder, - IMcpServerAuthorizationProvider authorizationProvider) + IServerAuthorizationProvider authorizationProvider) { Throw.IfNull(builder); Throw.IfNull(authorizationProvider); diff --git a/src/ModelContextProtocol/Protocol/Auth/McpAuthorizationOptions.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs similarity index 98% rename from src/ModelContextProtocol/Protocol/Auth/McpAuthorizationOptions.cs rename to src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs index 85af0a94..506309c5 100644 --- a/src/ModelContextProtocol/Protocol/Auth/McpAuthorizationOptions.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Provides authorization options for MCP clients. /// -public class McpAuthorizationOptions +public class AuthorizationOptions { /// /// Gets or sets a delegate that handles the OAuth 2.0 authorization code flow. diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs index ce240e8c..80edc6a2 100644 --- a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -25,7 +25,7 @@ internal class DefaultAuthorizationHandler : IAuthorizationHandler /// /// The logger factory. /// The authorization options. - public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, McpAuthorizationOptions? options = null) + public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, AuthorizationOptions? options = null) { _logger = loggerFactory != null ? loggerFactory.CreateLogger() @@ -47,7 +47,7 @@ public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, McpAuth /// The logger factory. /// A callback function that handles the authorization code flow. public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, Func>? authorizeCallback = null) - : this(loggerFactory, new McpAuthorizationOptions { AuthorizeCallback = authorizeCallback }) + : this(loggerFactory, new AuthorizationOptions { AuthorizeCallback = authorizeCallback }) { } @@ -90,7 +90,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp _logger.LogWarning("Failed to extract resource metadata from 401 response"); // Create a more specific exception - var exception = new McpAuthorizationException("Authorization required but no resource metadata available") + var exception = new AuthorizationException("Authorization required but no resource metadata available") { ResourceUri = serverUri.ToString() }; @@ -106,7 +106,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp _logger.LogWarning("Resource URL mismatch: expected {Expected}, got {Actual}", serverUri, resourceMetadata.Resource); - var exception = new McpAuthorizationException($"Resource URL mismatch: expected {serverUri}, got {resourceMetadata.Resource}"); + var exception = new AuthorizationException($"Resource URL mismatch: expected {serverUri}, got {resourceMetadata.Resource}"); exception.ResourceUri = resourceMetadata.Resource; throw exception; } @@ -116,7 +116,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp { _logger.LogWarning("No authorization servers found in resource metadata"); - var exception = new McpAuthorizationException("No authorization servers available"); + var exception = new AuthorizationException("No authorization servers available"); exception.ResourceUri = resourceMetadata.Resource; throw exception; } @@ -166,7 +166,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp { _logger.LogWarning("Authorization server does not support dynamic client registration and no client ID was provided"); - var exception = new McpAuthorizationException( + var exception = new AuthorizationException( "Authorization server does not support dynamic client registration and no client ID was provided. " + "Use McpAuthorizationOptions.ClientId to provide a pre-registered client ID."); exception.ResourceUri = resourceMetadata.Resource; @@ -179,7 +179,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp { _logger.LogWarning("No authorization callback provided, can't proceed with OAuth flow"); - var exception = new McpAuthorizationException( + var exception = new AuthorizationException( "Authentication is required but no authorization callback was provided. " + "Use McpAuthorizationOptions.AuthorizeCallback to provide a callback function."); exception.ResourceUri = resourceMetadata.Resource; @@ -227,11 +227,11 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp _logger.LogDebug("Successfully obtained access token"); return true; } - catch (Exception ex) when (ex is not McpAuthorizationException) + catch (Exception ex) when (ex is not AuthorizationException) { _logger.LogError(ex, "Failed to complete authorization flow"); - var authException = new McpAuthorizationException( + var authException = new AuthorizationException( $"Failed to complete authorization flow: {ex.Message}", ex, McpErrorCode.InvalidRequest); authException.ResourceUri = resourceMetadata.Resource; diff --git a/src/ModelContextProtocol/Protocol/Auth/IMcpServerAuthorizationProvider.cs b/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs similarity index 95% rename from src/ModelContextProtocol/Protocol/Auth/IMcpServerAuthorizationProvider.cs rename to src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs index 53f4ca2e..3d34def9 100644 --- a/src/ModelContextProtocol/Protocol/Auth/IMcpServerAuthorizationProvider.cs +++ b/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// This interface is implemented by authorization providers that enable MCP servers to validate tokens /// and control access to protected resources. /// -public interface IMcpServerAuthorizationProvider +public interface IServerAuthorizationProvider { /// /// Gets the Protected Resource Metadata (PRM) for the server. diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index c4d6c9de..ffac51f7 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -69,7 +69,7 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient /// The path for the redirect URI. Default is "/callback". /// The HTML content to display on successful authorization. If null, a default message is shown. /// The HTML template to display on failed authorization. If null, a default message is shown. Use {0} as a placeholder for the error message. - /// A delegate that can be used for the property. + /// A delegate that can be used for the property. /// /// /// This method creates a delegate that implements a complete OAuth 2.0 authorization code flow using an HTTP listener. @@ -211,7 +211,7 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient /// A function that opens a URL in the browser. /// The local port to listen on for the redirect URI. /// The path for the redirect URI. - /// A delegate that can be used for the property. + /// A delegate that can be used for the property. /// /// This is a convenience method that calls with "localhost" as the hostname. /// diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index 28420bb0..cb1797b3 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -88,5 +88,5 @@ public required Uri Endpoint /// /// /// - public McpAuthorizationOptions? AuthorizationOptions { get; init; } + public AuthorizationOptions? AuthorizationOptions { get; init; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs b/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs index b2361678..47557b00 100644 --- a/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs +++ b/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs @@ -15,5 +15,5 @@ public class AuthorizationCapability /// Gets or sets the authorization provider that handles token validation and provides /// metadata about the protected resource. /// - public IMcpServerAuthorizationProvider? AuthorizationProvider { get; set; } + public IServerAuthorizationProvider? AuthorizationProvider { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs b/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs index 925da2cd..f99fb154 100644 --- a/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs +++ b/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs @@ -4,14 +4,14 @@ namespace ModelContextProtocol.Server.Auth; /// -/// A simple implementation of that validates bearer tokens. +/// A simple implementation of that validates bearer tokens. /// /// /// This implementation is intended as a starting point for server developers. In production environments, /// it should be extended or replaced with a more robust implementation that integrates with your /// authentication system (e.g., OAuth 2.0 server, identity provider, etc.) /// -public class SimpleServerAuthorizationProvider : IMcpServerAuthorizationProvider +public class SimpleServerAuthorizationProvider : IServerAuthorizationProvider { private readonly ProtectedResourceMetadata _resourceMetadata; private readonly Func> _tokenValidator; From b0d9932bea5cc3aed352edddd1ad09a362809b17 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 22:23:22 -0700 Subject: [PATCH 11/28] No need to keep track of this --- samples/AuthorizationExample/Program.cs | 4 ++-- .../Protocol/Transport/SseClientTransport.cs | 19 ------------------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 54da48c2..7e23bc13 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -17,8 +17,8 @@ public static async Task Main(string[] args) // Configuration values for OAuth redirect string hostname = "localhost"; - int port = 8888; - string callbackPath = "/oauth/callback"; + int port = 13261; + string callbackPath = "/oauth/callback/"; // Set up the SSE transport with authorization support var transportOptions = new SseClientTransportOptions diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index ffac51f7..b6a233d4 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -205,25 +205,6 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient }; } - /// - /// Creates a delegate that can handle the OAuth 2.0 authorization code flow using a local HTTP listener. - /// - /// A function that opens a URL in the browser. - /// The local port to listen on for the redirect URI. - /// The path for the redirect URI. - /// A delegate that can be used for the property. - /// - /// This is a convenience method that calls with "localhost" as the hostname. - /// - [Obsolete("Use CreateHttpListenerAuthorizeCallback instead. This method will be removed in a future version.")] - public static Func> CreateLocalServerAuthorizeCallback( - Func openBrowser, - int listenPort = 8888, - string redirectPath = "/callback") - { - return CreateHttpListenerAuthorizeCallback(openBrowser, "localhost", listenPort, redirectPath); - } - /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { From e6c199591c552d8fd2ac1ce7b3beb86f46d3e336 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 23 Apr 2025 23:23:41 -0700 Subject: [PATCH 12/28] Updated logc --- samples/AuthorizationExample/Program.cs | 3 +- .../Protocol/Transport/SseClientTransport.cs | 221 ++++++++++-------- 2 files changed, 119 insertions(+), 105 deletions(-) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 7e23bc13..15ce6f65 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -47,8 +47,7 @@ public static async Task Main(string[] args) }, hostname: hostname, listenPort: port, - redirectPath: callbackPath, - successHtml: "

Authorization Successful

You have successfully authorized the application. You can close this window and return to the app.

" + redirectPath: callbackPath ) } }; diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index b6a233d4..8e83acad 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -1,6 +1,8 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Utils; +using System.Net; +using System.Text; namespace ModelContextProtocol.Protocol.Transport; @@ -61,147 +63,160 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient public string Name { get; } /// - /// Creates a delegate that can handle the OAuth 2.0 authorization code flow using an HTTP listener. + /// Creates a callback function for handling OAuth 2.0 authorization flows using an HTTP listener. /// - /// A function that opens a URL in the browser. - /// The hostname to listen on for the redirect URI. Default is "localhost". - /// The port to listen on for the redirect URI. Default is 8888. - /// The path for the redirect URI. Default is "/callback". - /// The HTML content to display on successful authorization. If null, a default message is shown. - /// The HTML template to display on failed authorization. If null, a default message is shown. Use {0} as a placeholder for the error message. - /// A delegate that can be used for the property. - /// - /// - /// This method creates a delegate that implements a complete OAuth 2.0 authorization code flow using an HTTP listener. - /// When called, it will: - /// - /// - /// Open the authorization URL in the browser - /// Start an HTTP listener to receive the authorization code - /// Return the redirect URI and authorization code when received - /// - /// - /// You can customize the hostname, port, and path for the redirect URI to match your OAuth client configuration. - /// - /// + /// A function to open the browser to the authorization URL. + /// The hostname for the HTTP listener. Defaults to "localhost". + /// The port for the HTTP listener. Defaults to 8888. + /// The redirect path for the HTTP listener. Defaults to "/callback". + /// + /// A function that takes and returns a task that resolves to a tuple containing + /// the redirect URI and the authorization code. + /// public static Func> CreateHttpListenerAuthorizeCallback( Func openBrowser, string hostname = "localhost", int listenPort = 8888, - string redirectPath = "/callback", - string? successHtml = null, - string? errorHtml = null) + string redirectPath = "/callback") { return async (ClientMetadata clientMetadata) => { - // Default redirect URI based on parameters - var defaultRedirectUri = $"http://{hostname}:{listenPort}{redirectPath}"; - - // First, try to find a matching redirect URI from the client metadata - var redirectUri = defaultRedirectUri; - var hostPrefix = $"http://{hostname}"; - + string redirectUri = $"http://{hostname}:{listenPort}{redirectPath}"; + foreach (var uri in clientMetadata.RedirectUris) { - if (uri.StartsWith(hostPrefix, StringComparison.OrdinalIgnoreCase)) + if (uri.StartsWith($"http://{hostname}", StringComparison.OrdinalIgnoreCase) && + Uri.TryCreate(uri, UriKind.Absolute, out var parsedUri)) { redirectUri = uri; - - // Parse the port and path from the selected URI to ensure we listen on the correct endpoint - if (Uri.TryCreate(uri, UriKind.Absolute, out var parsedUri)) - { - listenPort = parsedUri.IsDefaultPort ? 80 : parsedUri.Port; - redirectPath = parsedUri.AbsolutePath; - } - + listenPort = parsedUri.IsDefaultPort ? 80 : parsedUri.Port; + redirectPath = parsedUri.AbsolutePath; break; } } - - // Use a TaskCompletionSource to wait for the authorization code + var authCodeTcs = new TaskCompletionSource(); - - // Start an HTTP listener to listen for the authorization code - using var listener = new System.Net.HttpListener(); - - // Ensure the URI format is correct for HttpListener - var listenerPrefix = $"http://{hostname}:{listenPort}/"; - if (redirectPath.Length > 1) + // Ensure the path has a trailing slash for the HttpListener prefix + string listenerPrefix = $"http://{hostname}:{listenPort}{redirectPath}"; + if (!listenerPrefix.EndsWith("/")) { - // If path is something like "/callback", we need to listen on all paths that start with it - var basePath = redirectPath.TrimEnd('/').TrimStart('/'); - listenerPrefix = $"http://{hostname}:{listenPort}/{basePath}/"; + listenerPrefix += "/"; } - + + using var listener = new HttpListener(); listener.Prefixes.Add(listenerPrefix); - listener.Start(); - // Default HTML responses - var defaultSuccessHtml = "

Authorization Successful

You can now close this window and return to the application.

"; - var defaultErrorHtml = "

Authorization Failed

Error: {0}

"; + // Start the listener BEFORE opening the browser + try + { + listener.Start(); + } + catch (HttpListenerException ex) + { + throw new McpException($"Failed to start HTTP listener on {listenerPrefix}: {ex.Message}", McpErrorCode.InvalidRequest); + } + + // Create a cancellation token source with a timeout + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); - // Start listening for the callback asynchronously - var listenerTask = Task.Run(async () => + _ = Task.Run(async () => { try { - var context = await listener.GetContextAsync(); - var request = context.Request; - - // Get the authorization code from the query string - var code = request.QueryString["code"]; - var error = request.QueryString["error"]; + // GetContextAsync doesn't accept a cancellation token, so we need to handle cancellation manually + var contextTask = listener.GetContextAsync(); + var completedTask = await Task.WhenAny(contextTask, Task.Delay(Timeout.Infinite, cts.Token)); - // Send a response to the browser - var response = context.Response; - response.ContentType = "text/html"; - string responseHtml; - - if (!string.IsNullOrEmpty(error)) + if (completedTask == contextTask) { - responseHtml = string.Format(errorHtml ?? defaultErrorHtml, error); - authCodeTcs.SetException(new McpException($"Authorization failed: {error}", McpErrorCode.InvalidRequest)); - } - else if (string.IsNullOrEmpty(code)) - { - responseHtml = string.Format(errorHtml ?? defaultErrorHtml, "No authorization code received"); - authCodeTcs.SetException(new McpException("No authorization code received", McpErrorCode.InvalidRequest)); - } - else - { - responseHtml = successHtml ?? defaultSuccessHtml; - authCodeTcs.SetResult(code); + var context = await contextTask; + var request = context.Request; + var response = context.Response; + + string? code = request.QueryString["code"]; + string? error = request.QueryString["error"]; + string html; + string? resultCode = null; + + if (!string.IsNullOrEmpty(error)) + { + html = $"

Authorization Failed

Error: {WebUtility.HtmlEncode(error)}

"; + } + else if (string.IsNullOrEmpty(code)) + { + html = "

Authorization Failed

No authorization code received.

"; + } + else + { + html = "

Authorization Successful

You may now close this window.

"; + resultCode = code; + } + + try + { + // Send response to browser + byte[] buffer = Encoding.UTF8.GetBytes(html); + response.ContentType = "text/html"; + response.ContentLength64 = buffer.Length; + response.OutputStream.Write(buffer, 0, buffer.Length); + + // IMPORTANT: Explicitly close the response to ensure it's fully sent + response.Close(); + + // Now that we've finished processing the browser response, + // we can safely signal completion or failure with the auth code + if (resultCode != null) + { + authCodeTcs.TrySetResult(resultCode); + } + else if (!string.IsNullOrEmpty(error)) + { + authCodeTcs.TrySetException(new McpException($"Authorization failed: {error}", McpErrorCode.InvalidRequest)); + } + else + { + authCodeTcs.TrySetException(new McpException("No authorization code received", McpErrorCode.InvalidRequest)); + } + } + catch (Exception ex) + { + authCodeTcs.TrySetException(new McpException($"Error processing browser response: {ex.Message}", McpErrorCode.InvalidRequest)); + } } - - var buffer = System.Text.Encoding.UTF8.GetBytes(responseHtml); - response.ContentLength64 = buffer.Length; - await response.OutputStream.WriteAsync(buffer, 0, buffer.Length); - response.Close(); } catch (Exception ex) { authCodeTcs.TrySetException(ex); } - finally - { - listener.Close(); - } }); - - // Open the authorization URL in the browser - if (clientMetadata.ClientUri != null) + + // Now open the browser AFTER the listener is started + if (!string.IsNullOrEmpty(clientMetadata.ClientUri)) { - await openBrowser(clientMetadata.ClientUri); + await openBrowser(clientMetadata.ClientUri!); } else { - authCodeTcs.SetException(new McpException("No authorization URL provided in client metadata", McpErrorCode.InvalidRequest)); + // Stop the listener before throwing + listener.Stop(); + throw new McpException("Client URI is missing in metadata.", McpErrorCode.InvalidRequest); + } + + try + { + // Use a timeout to avoid hanging indefinitely + string authCode = await authCodeTcs.Task.WaitAsync(cts.Token); + return (redirectUri, authCode); + } + catch (OperationCanceledException) + { + throw new McpException("Authorization timed out after 5 minutes.", McpErrorCode.InvalidRequest); + } + finally + { + // Ensure the listener is stopped when we're done + listener.Stop(); } - - // Wait for the authorization code - var code = await authCodeTcs.Task; - - return (redirectUri, code); }; } From 2f44765b267756175764294a453969e8a9776ac2 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 24 Apr 2025 00:02:24 -0700 Subject: [PATCH 13/28] Update with proper token logic --- samples/AuthorizationExample/Program.cs | 2 + samples/AuthorizationServerExample/Program.cs | 7 +- .../AuthorizationMiddleware.cs | 9 +- .../McpServerAuthorizationExtensions.cs | 7 +- .../Auth/DefaultAuthorizationHandler.cs | 5 +- .../Transport/SseClientSessionTransport.cs | 110 ++++++------------ .../Protocol/Types/ServerCapabilities.cs | 6 - src/ModelContextProtocol/Server/McpServer.cs | 2 - 8 files changed, 50 insertions(+), 98 deletions(-) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 15ce6f65..35f46def 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -86,6 +86,8 @@ public static async Task Main(string[] args) { Console.WriteLine($"Inner Error: {ex.InnerException.Message}"); } + // Print the stack trace for debugging + Console.WriteLine($"Stack Trace:\n{ex.StackTrace}"); } } } \ No newline at end of file diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs index 97dfe309..ac375fff 100644 --- a/samples/AuthorizationServerExample/Program.cs +++ b/samples/AuthorizationServerExample/Program.cs @@ -35,11 +35,8 @@ public static async Task Main(string[] args) // In a real application, this would verify the token with your identity provider async Task ValidateToken(string token) { - // For demo purposes, we'll accept any token that starts with "valid_" - // In production, you would validate the token with your identity provider - var isValid = token.StartsWith("valid_", StringComparison.OrdinalIgnoreCase); - Console.WriteLine($"Token validation result: {(isValid ? "Valid" : "Invalid")}"); - return isValid; + // For demo purposes, we'll accept any token. + return true; } // 3. Create an authorization provider with the PRM and token validator diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs index ce7f9b69..0c6ad374 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs @@ -32,13 +32,14 @@ public AuthorizationMiddleware(RequestDelegate next, ILogger /// The HTTP context. /// The MCP server options. + /// The authorization provider. /// A representing the asynchronous operation. - public async Task InvokeAsync(HttpContext context, IOptions serverOptions) + public async Task InvokeAsync( + HttpContext context, + IOptions serverOptions, + IServerAuthorizationProvider? authProvider = null) { // Check if authorization is configured - var authCapability = serverOptions.Value.Capabilities?.Authorization; - var authProvider = authCapability?.AuthorizationProvider; - if (authProvider == null) { // Authorization is not configured, proceed to the next middleware diff --git a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs index ab29813e..7592b5e9 100644 --- a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs @@ -24,13 +24,12 @@ public static IMcpServerBuilder WithAuthorization( Throw.IfNull(builder); Throw.IfNull(authorizationProvider); + // Register the authorization provider in the DI container + builder.Services.AddSingleton(authorizationProvider); + builder.Services.Configure(options => { options.Capabilities ??= new ServerCapabilities(); - options.Capabilities.Authorization = new AuthorizationCapability - { - AuthorizationProvider = authorizationProvider - }; }); return builder; diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs index 80edc6a2..326f82fb 100644 --- a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -1,9 +1,8 @@ -using System.Diagnostics; -using System.Net; -using System.Net.Http.Headers; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Utils; +using System.Net; +using System.Net.Http.Headers; namespace ModelContextProtocol.Protocol.Auth; diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 2f5ad8f3..aa341880 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -89,29 +89,12 @@ public override async Task SendMessageAsync( if (_messageEndpoint == null) throw new InvalidOperationException("Transport not connected"); - using var content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" - ); - string messageId = "(no id)"; if (message is JsonRpcMessageWithId messageWithId) { messageId = messageWithId.Id.ToString(); } - - using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) - { - Content = content, - }; - - // Add authorization headers if needed - await _authorizationHandler.AuthenticateRequestAsync(httpRequestMessage).ConfigureAwait(false); - - // Copy additional headers - CopyAdditionalHeaders(httpRequestMessage.Headers); // Send the request, handling potential auth challenges HttpResponseMessage? response = null; @@ -120,37 +103,32 @@ public override async Task SendMessageAsync( do { authRetry = false; - response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); - // Handle 401 Unauthorized response + // Create a new request for each attempt + using var currentRequest = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint); + currentRequest.Content = new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json" + ); + + // Add authorization headers if needed - the handler will only add headers if auth is required + await _authorizationHandler.AuthenticateRequestAsync(currentRequest).ConfigureAwait(false); + + // Copy additional headers + CopyAdditionalHeaders(currentRequest.Headers); + + // Dispose previous response before making a new request + response?.Dispose(); + + response = await _httpClient.SendAsync(currentRequest, cancellationToken).ConfigureAwait(false); + + // Handle 401 Unauthorized response - this will only execute if the server requires auth if (response.StatusCode == HttpStatusCode.Unauthorized) { // Try to handle the unauthorized response authRetry = await _authorizationHandler.HandleUnauthorizedResponseAsync( response, _messageEndpoint).ConfigureAwait(false); - - if (authRetry) - { - // Create a new request (we can't reuse the previous one) - using var newRequest = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) - { - Content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" - ) - }; - - // Add authorization headers for the new request - await _authorizationHandler.AuthenticateRequestAsync(newRequest).ConfigureAwait(false); - CopyAdditionalHeaders(newRequest.Headers); - - // Dispose the previous response - response.Dispose(); - - // Send the new request - response = await _httpClient.SendAsync(newRequest, cancellationToken).ConfigureAwait(false); - } } } while (authRetry); @@ -252,15 +230,6 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { try { - using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); - request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - - // Add authorization headers if needed - await _authorizationHandler.AuthenticateRequestAsync(request).ConfigureAwait(false); - - // Copy additional headers - CopyAdditionalHeaders(request.Headers); - // Send the request, handling potential auth challenges HttpResponseMessage? response = null; bool authRetry = false; @@ -268,39 +237,32 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) do { authRetry = false; + + // Create a new request for each attempt + using var currentRequest = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); + currentRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + + // Add authorization headers if needed - the handler will only add headers if auth is required + await _authorizationHandler.AuthenticateRequestAsync(currentRequest).ConfigureAwait(false); + + // Copy additional headers + CopyAdditionalHeaders(currentRequest.Headers); + + // Dispose previous response before making a new request + response?.Dispose(); + response = await _httpClient.SendAsync( - request, + currentRequest, HttpCompletionOption.ResponseHeadersRead, cancellationToken ).ConfigureAwait(false); - // Handle 401 Unauthorized response + // Handle 401 Unauthorized response - this will only execute if the server requires auth if (response.StatusCode == HttpStatusCode.Unauthorized) { // Try to handle the unauthorized response authRetry = await _authorizationHandler.HandleUnauthorizedResponseAsync( response, _sseEndpoint).ConfigureAwait(false); - - if (authRetry) - { - // Create a new request (we can't reuse the previous one) - using var newRequest = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); - newRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - - // Add authorization headers for the new request - await _authorizationHandler.AuthenticateRequestAsync(newRequest).ConfigureAwait(false); - CopyAdditionalHeaders(newRequest.Headers); - - // Dispose the previous response - response.Dispose(); - - // Send the new request - response = await _httpClient.SendAsync( - newRequest, - HttpCompletionOption.ResponseHeadersRead, - cancellationToken - ).ConfigureAwait(false); - } } } while (authRetry); diff --git a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs index 8e524845..6406ea4d 100644 --- a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs @@ -35,12 +35,6 @@ public class ServerCapabilities [JsonPropertyName("experimental")] public Dictionary? Experimental { get; set; } - /// - /// Gets or sets a server's authorization capability, supporting OAuth 2.0 authorization flows. - /// - [JsonPropertyName("authorization")] - public AuthorizationCapability? Authorization { get; set; } - /// /// Gets or sets a server's logging capability, supporting sending log messages to the client. /// diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index aeee62dd..7faad5c9 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -328,7 +328,6 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals ServerCapabilities = new() { Experimental = options.Capabilities?.Experimental, - Authorization = options.Capabilities?.Authorization, Logging = options.Capabilities?.Logging, Tools = options.Capabilities?.Tools, Resources = options.Capabilities?.Resources, @@ -427,7 +426,6 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) ServerCapabilities = new() { Experimental = options.Capabilities?.Experimental, - Authorization = options.Capabilities?.Authorization, Logging = options.Capabilities?.Logging, Prompts = options.Capabilities?.Prompts, Resources = options.Capabilities?.Resources, From bf9f63eba402561449852410d4c8c05c288eb5b7 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 24 Apr 2025 00:05:58 -0700 Subject: [PATCH 14/28] Cleanup of unused declarations --- .../Protocol/Auth/AuthorizationContext.cs | 2 -- .../Protocol/Auth/AuthorizationOptions.cs | 3 -- .../Protocol/Auth/AuthorizationService.cs | 4 +-- .../Protocol/Auth/IAuthorizationHandler.cs | 2 -- .../Auth/IServerAuthorizationProvider.cs | 2 -- .../Auth/SimpleServerAuthorizationProvider.cs | 29 +++++++------------ 6 files changed, 13 insertions(+), 29 deletions(-) diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs index c210fd9e..df4c8e59 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs @@ -1,5 +1,3 @@ -using System.Diagnostics; - namespace ModelContextProtocol.Protocol.Auth; /// diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs index 506309c5..232de02e 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs @@ -1,6 +1,3 @@ -using System; -using System.Collections.Generic; - namespace ModelContextProtocol.Protocol.Auth; /// diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs index a6bc5807..4ff7222d 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -1,10 +1,10 @@ +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; using System.Net; using System.Net.Http.Headers; using System.Security.Cryptography; using System.Text; using System.Text.Json; -using ModelContextProtocol.Utils; -using ModelContextProtocol.Utils.Json; namespace ModelContextProtocol.Protocol.Auth; diff --git a/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs index 85f5a61d..ffa41acb 100644 --- a/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs @@ -1,5 +1,3 @@ -using System.Net; - namespace ModelContextProtocol.Protocol.Auth; /// diff --git a/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs b/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs index 3d34def9..4aee3f22 100644 --- a/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs +++ b/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs @@ -1,5 +1,3 @@ -using System.Text.Json; - namespace ModelContextProtocol.Protocol.Auth; /// diff --git a/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs b/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs index f99fb154..fdd9b925 100644 --- a/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs +++ b/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Protocol.Auth; -using System.Security.Claims; namespace ModelContextProtocol.Server.Auth; @@ -11,24 +10,18 @@ namespace ModelContextProtocol.Server.Auth; /// it should be extended or replaced with a more robust implementation that integrates with your /// authentication system (e.g., OAuth 2.0 server, identity provider, etc.) /// -public class SimpleServerAuthorizationProvider : IServerAuthorizationProvider +/// +/// Initializes a new instance of the class +/// with the specified resource metadata and token validator. +/// +/// The protected resource metadata. +/// A function that validates access tokens. If not provided, a function that always returns true will be used. +public class SimpleServerAuthorizationProvider( + ProtectedResourceMetadata resourceMetadata, + Func>? tokenValidator = null) : IServerAuthorizationProvider { - private readonly ProtectedResourceMetadata _resourceMetadata; - private readonly Func> _tokenValidator; - - /// - /// Initializes a new instance of the class - /// with the specified resource metadata and token validator. - /// - /// The protected resource metadata. - /// A function that validates access tokens. If not provided, a function that always returns true will be used. - public SimpleServerAuthorizationProvider( - ProtectedResourceMetadata resourceMetadata, - Func>? tokenValidator = null) - { - _resourceMetadata = resourceMetadata ?? throw new ArgumentNullException(nameof(resourceMetadata)); - _tokenValidator = tokenValidator ?? (_ => Task.FromResult(true)); - } + private readonly ProtectedResourceMetadata _resourceMetadata = resourceMetadata ?? throw new ArgumentNullException(nameof(resourceMetadata)); + private readonly Func> _tokenValidator = tokenValidator ?? (_ => Task.FromResult(true)); /// public ProtectedResourceMetadata GetProtectedResourceMetadata() => _resourceMetadata; From 3fd7681b06f23de25092bcbd9400cc4707421935 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Wed, 30 Apr 2025 23:50:03 -0700 Subject: [PATCH 15/28] Remove handler from transport definition --- samples/AuthorizationExample/Program.cs | 6 +- .../Auth/AuthorizationServerMetadata.cs | 2 +- .../Protocol/Auth/AuthorizationService.cs | 163 +++++++++++++++++- .../Auth/ClientRegistrationResponse.cs | 2 +- .../Protocol/Auth/ResourceMetadata.cs | 2 +- .../Protocol/Auth/TokenResponse.cs | 2 +- .../Protocol/Transport/SseClientTransport.cs | 160 ----------------- .../Transport/SseClientTransportOptions.cs | 5 +- 8 files changed, 168 insertions(+), 174 deletions(-) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 35f46def..42b1c83c 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -36,10 +36,8 @@ public static async Task Main(string[] args) RedirectUris = new[] { $"http://{hostname}:{port}{callbackPath}" - }, - - // Configure the authorize callback with the same hostname, port, and path - AuthorizeCallback = SseClientTransport.CreateHttpListenerAuthorizeCallback( + }, // Configure the authorize callback with the same hostname, port, and path + AuthorizeCallback = AuthorizationService.CreateHttpListenerAuthorizeCallback( openBrowser: async (url) => { Console.WriteLine($"Opening browser to authorize at: {url}"); diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs index 9be69e67..56ce385f 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Represents OAuth 2.0 authorization server metadata as defined in RFC 8414. /// -internal class AuthorizationServerMetadata +public class AuthorizationServerMetadata { /// /// Gets or sets the authorization endpoint URL. diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs index 4ff7222d..8db6028a 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -11,7 +11,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Provides OAuth 2.0 authorization services for MCP clients. /// -internal class AuthorizationService +public class AuthorizationService { private static readonly HttpClient s_httpClient = new() { @@ -448,9 +448,166 @@ private static Dictionary ParseAuthHeaderParameters(string param break; start = commaPos + 1; - } - } + } } return result; } + + /// + /// Creates an HTTP listener callback for handling OAuth 2.0 authorization code flow. + /// + /// A function that opens a browser with the given URL. + /// The hostname to listen on. Defaults to "localhost". + /// The port to listen on. Defaults to 8888. + /// The redirect path for the HTTP listener. Defaults to "/callback". + /// + /// A function that takes and returns a task that resolves to a tuple containing + /// the redirect URI and the authorization code. + /// + public static Func> CreateHttpListenerAuthorizeCallback( + Func openBrowser, + string hostname = "localhost", + int listenPort = 8888, + string redirectPath = "/callback") + { + return async (ClientMetadata clientMetadata) => + { + string redirectUri = $"http://{hostname}:{listenPort}{redirectPath}"; + + foreach (var uri in clientMetadata.RedirectUris) + { + if (uri.StartsWith($"http://{hostname}", StringComparison.OrdinalIgnoreCase) && + Uri.TryCreate(uri, UriKind.Absolute, out var parsedUri)) + { + redirectUri = uri; + listenPort = parsedUri.IsDefaultPort ? 80 : parsedUri.Port; + redirectPath = parsedUri.AbsolutePath; + break; + } + } + + var authCodeTcs = new TaskCompletionSource(); + // Ensure the path has a trailing slash for the HttpListener prefix + string listenerPrefix = $"http://{hostname}:{listenPort}{redirectPath}"; + if (!listenerPrefix.EndsWith("/")) + { + listenerPrefix += "/"; + } + + using var listener = new HttpListener(); + listener.Prefixes.Add(listenerPrefix); + + // Start the listener BEFORE opening the browser + try + { + listener.Start(); + } + catch (HttpListenerException ex) + { + throw new McpException($"Failed to start HTTP listener on {listenerPrefix}: {ex.Message}", McpErrorCode.InvalidRequest); + } + + // Create a cancellation token source with a timeout + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); + + _ = Task.Run(async () => + { + try + { + // GetContextAsync doesn't accept a cancellation token, so we need to handle cancellation manually + var contextTask = listener.GetContextAsync(); + var completedTask = await Task.WhenAny(contextTask, Task.Delay(Timeout.Infinite, cts.Token)); + + if (completedTask == contextTask) + { + var context = await contextTask; + var request = context.Request; + var response = context.Response; + + string? code = request.QueryString["code"]; + string? error = request.QueryString["error"]; + string html; + string? resultCode = null; + + if (!string.IsNullOrEmpty(error)) + { + html = $"

Authorization Failed

Error: {WebUtility.HtmlEncode(error)}

"; + } + else if (string.IsNullOrEmpty(code)) + { + html = "

Authorization Failed

No authorization code received.

"; + } + else + { + html = "

Authorization Successful

You may now close this window.

"; + resultCode = code; + } + + try + { + // Send response to browser + byte[] buffer = Encoding.UTF8.GetBytes(html); + response.ContentType = "text/html"; + response.ContentLength64 = buffer.Length; + response.OutputStream.Write(buffer, 0, buffer.Length); + + // IMPORTANT: Explicitly close the response to ensure it's fully sent + response.Close(); + + // Now that we've finished processing the browser response, + // we can safely signal completion or failure with the auth code + if (resultCode != null) + { + authCodeTcs.TrySetResult(resultCode); + } + else if (!string.IsNullOrEmpty(error)) + { + authCodeTcs.TrySetException(new McpException($"Authorization failed: {error}", McpErrorCode.InvalidRequest)); + } + else + { + authCodeTcs.TrySetException(new McpException("No authorization code received", McpErrorCode.InvalidRequest)); + } + } + catch (Exception ex) + { + authCodeTcs.TrySetException(new McpException($"Error processing browser response: {ex.Message}", McpErrorCode.InvalidRequest)); + } + } + } + catch (Exception ex) + { + authCodeTcs.TrySetException(ex); + } + }); + + // Now open the browser AFTER the listener is started + if (!string.IsNullOrEmpty(clientMetadata.ClientUri)) + { + await openBrowser(clientMetadata.ClientUri!); + } + else + { + // Stop the listener before throwing + listener.Stop(); + throw new McpException("Client URI is missing in metadata.", McpErrorCode.InvalidRequest); + } + + try + { + // Use a timeout to avoid hanging indefinitely + string authCode = await authCodeTcs.Task.WaitAsync(cts.Token); + return (redirectUri, authCode); + } + catch (OperationCanceledException) + { + throw new McpException("Authorization timed out after 5 minutes.", McpErrorCode.InvalidRequest); + } + finally + { + // Ensure the listener is stopped when we're done + listener.Stop(); + } + }; + } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs b/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs index 06cef8b5..d7042b3a 100644 --- a/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs +++ b/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Represents the OAuth 2.0 client registration response as defined in RFC 7591. /// -internal class ClientRegistrationResponse +public class ClientRegistrationResponse { /// /// Gets or sets the OAuth 2.0 client identifier string. diff --git a/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs index bf6613a1..a57456c3 100644 --- a/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs +++ b/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Represents the resource metadata from the WWW-Authenticate header in a 401 Unauthorized response. /// -internal class ResourceMetadata +public class ResourceMetadata { /// /// Gets or sets the resource identifier URI. diff --git a/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs b/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs index 2c5faefe..d6b33489 100644 --- a/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs +++ b/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Represents the OAuth 2.0 token response as defined in RFC 6749. /// -internal class TokenResponse +public class TokenResponse { /// /// Gets or sets the access token issued by the authorization server. diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 8e83acad..56889724 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -1,8 +1,6 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Utils; -using System.Net; -using System.Text; namespace ModelContextProtocol.Protocol.Transport; @@ -62,164 +60,6 @@ public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient /// public string Name { get; } - /// - /// Creates a callback function for handling OAuth 2.0 authorization flows using an HTTP listener. - /// - /// A function to open the browser to the authorization URL. - /// The hostname for the HTTP listener. Defaults to "localhost". - /// The port for the HTTP listener. Defaults to 8888. - /// The redirect path for the HTTP listener. Defaults to "/callback". - /// - /// A function that takes and returns a task that resolves to a tuple containing - /// the redirect URI and the authorization code. - /// - public static Func> CreateHttpListenerAuthorizeCallback( - Func openBrowser, - string hostname = "localhost", - int listenPort = 8888, - string redirectPath = "/callback") - { - return async (ClientMetadata clientMetadata) => - { - string redirectUri = $"http://{hostname}:{listenPort}{redirectPath}"; - - foreach (var uri in clientMetadata.RedirectUris) - { - if (uri.StartsWith($"http://{hostname}", StringComparison.OrdinalIgnoreCase) && - Uri.TryCreate(uri, UriKind.Absolute, out var parsedUri)) - { - redirectUri = uri; - listenPort = parsedUri.IsDefaultPort ? 80 : parsedUri.Port; - redirectPath = parsedUri.AbsolutePath; - break; - } - } - - var authCodeTcs = new TaskCompletionSource(); - // Ensure the path has a trailing slash for the HttpListener prefix - string listenerPrefix = $"http://{hostname}:{listenPort}{redirectPath}"; - if (!listenerPrefix.EndsWith("/")) - { - listenerPrefix += "/"; - } - - using var listener = new HttpListener(); - listener.Prefixes.Add(listenerPrefix); - - // Start the listener BEFORE opening the browser - try - { - listener.Start(); - } - catch (HttpListenerException ex) - { - throw new McpException($"Failed to start HTTP listener on {listenerPrefix}: {ex.Message}", McpErrorCode.InvalidRequest); - } - - // Create a cancellation token source with a timeout - using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); - - _ = Task.Run(async () => - { - try - { - // GetContextAsync doesn't accept a cancellation token, so we need to handle cancellation manually - var contextTask = listener.GetContextAsync(); - var completedTask = await Task.WhenAny(contextTask, Task.Delay(Timeout.Infinite, cts.Token)); - - if (completedTask == contextTask) - { - var context = await contextTask; - var request = context.Request; - var response = context.Response; - - string? code = request.QueryString["code"]; - string? error = request.QueryString["error"]; - string html; - string? resultCode = null; - - if (!string.IsNullOrEmpty(error)) - { - html = $"

Authorization Failed

Error: {WebUtility.HtmlEncode(error)}

"; - } - else if (string.IsNullOrEmpty(code)) - { - html = "

Authorization Failed

No authorization code received.

"; - } - else - { - html = "

Authorization Successful

You may now close this window.

"; - resultCode = code; - } - - try - { - // Send response to browser - byte[] buffer = Encoding.UTF8.GetBytes(html); - response.ContentType = "text/html"; - response.ContentLength64 = buffer.Length; - response.OutputStream.Write(buffer, 0, buffer.Length); - - // IMPORTANT: Explicitly close the response to ensure it's fully sent - response.Close(); - - // Now that we've finished processing the browser response, - // we can safely signal completion or failure with the auth code - if (resultCode != null) - { - authCodeTcs.TrySetResult(resultCode); - } - else if (!string.IsNullOrEmpty(error)) - { - authCodeTcs.TrySetException(new McpException($"Authorization failed: {error}", McpErrorCode.InvalidRequest)); - } - else - { - authCodeTcs.TrySetException(new McpException("No authorization code received", McpErrorCode.InvalidRequest)); - } - } - catch (Exception ex) - { - authCodeTcs.TrySetException(new McpException($"Error processing browser response: {ex.Message}", McpErrorCode.InvalidRequest)); - } - } - } - catch (Exception ex) - { - authCodeTcs.TrySetException(ex); - } - }); - - // Now open the browser AFTER the listener is started - if (!string.IsNullOrEmpty(clientMetadata.ClientUri)) - { - await openBrowser(clientMetadata.ClientUri!); - } - else - { - // Stop the listener before throwing - listener.Stop(); - throw new McpException("Client URI is missing in metadata.", McpErrorCode.InvalidRequest); - } - - try - { - // Use a timeout to avoid hanging indefinitely - string authCode = await authCodeTcs.Task.WaitAsync(cts.Token); - return (redirectUri, authCode); - } - catch (OperationCanceledException) - { - throw new McpException("Authorization timed out after 5 minutes.", McpErrorCode.InvalidRequest); - } - finally - { - // Ensure the listener is stopped when we're done - listener.Stop(); - } - }; - } - /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index cb1797b3..e430b907 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -73,14 +73,13 @@ public required Uri Endpoint /// Example: /// /// var transportOptions = new SseClientTransportOptions - /// { - /// Endpoint = new Uri("http://localhost:7071/sse"), + /// { /// Endpoint = new Uri("http://localhost:7071/sse"), /// AuthorizationOptions = new McpAuthorizationOptions /// { /// ClientId = "my-client-id", /// ClientSecret = "my-client-secret", /// RedirectUris = new[] { "http://localhost:8888/callback" }, - /// AuthorizeCallback = SseClientTransport.CreateHttpListenerAuthorizeCallback( + /// AuthorizeCallback = AuthorizationService.CreateHttpListenerAuthorizeCallback( /// openBrowser: url => Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }) /// ) /// } From 9bf4ea342bffa5da5b5b6e36b8c3230d3bc010c5 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 00:08:31 -0700 Subject: [PATCH 16/28] Amend middleware logic --- .../AuthorizationMiddleware.cs | 41 ++-------- .../HttpMcpServerBuilderExtensions.cs | 3 +- .../McpAuthorizationFilter.cs | 79 +++++++++++++++++++ .../McpAuthorizationFilterFactory.cs | 53 +++++++++++++ .../McpEndpointRouteBuilderExtensions.cs | 38 ++++++++- .../ProtectedResourceMetadataHandler.cs | 44 +++++++++++ 6 files changed, 220 insertions(+), 38 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs create mode 100644 src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs create mode 100644 src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs index 0c6ad374..d11bdf11 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs @@ -47,8 +47,9 @@ public async Task InvokeAsync( return; } - // Handle the PRM document endpoint - if (context.Request.Path.StartsWithSegments("/.well-known/oauth-protected-resource")) + // Handle the PRM document endpoint if not handled by the endpoint + if (context.Request.Path.StartsWithSegments("/.well-known/oauth-protected-resource") && + context.GetEndpoint() == null) { _logger.LogDebug("Serving Protected Resource Metadata document"); context.Response.ContentType = "application/json"; @@ -59,40 +60,8 @@ await JsonSerializer.SerializeAsync( return; } - // Serve SSE and message endpoints with authorization - if (context.Request.Path.StartsWithSegments("/sse") || - (context.Request.Path.Value?.EndsWith("/message") == true)) - { - // Check if the Authorization header is present - if (!context.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader)) - { - // No Authorization header present, return 401 Unauthorized - var prm = authProvider.GetProtectedResourceMetadata(); - var prmUrl = GetPrmUrl(context, prm.Resource); - - _logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header"); - context.Response.StatusCode = StatusCodes.Status401Unauthorized; - context.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); - return; - } - - // Validate the token - ensuring authHeader is a non-null string - string authHeaderValue = authHeader.ToString(); - bool isValid = await authProvider.ValidateTokenAsync(authHeaderValue); - if (!isValid) - { - // Invalid token, return 401 Unauthorized - var prm = authProvider.GetProtectedResourceMetadata(); - var prmUrl = GetPrmUrl(context, prm.Resource); - - _logger.LogDebug("Invalid authorization token, returning 401 Unauthorized"); - context.Response.StatusCode = StatusCodes.Status401Unauthorized; - context.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); - return; - } - } - - // Token is valid or endpoint doesn't require authentication, proceed to the next middleware + // Proceed to the next middleware - authorization for SSE and message endpoints + // is now handled by endpoint filters await _next(context); } diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index c9a5ba87..907109e6 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -18,12 +18,13 @@ public static class HttpMcpServerBuilderExtensions /// Configures options for the Streamable HTTP transport. This allows configuring per-session /// and running logic before and after a session. /// The builder provided in . - /// is . + /// is . public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder, Action? configureOptions = null) { ArgumentNullException.ThrowIfNull(builder); builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); + builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); if (configureOptions is not null) diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs new file mode 100644 index 00000000..be5f0a72 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs @@ -0,0 +1,79 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// An endpoint filter that handles authorization for MCP endpoints. +/// +internal class McpAuthorizationFilter : IEndpointFilter +{ + private readonly ILogger _logger; + private readonly IServerAuthorizationProvider _authProvider; + + /// + /// Initializes a new instance of the class. + /// + /// The logger. + /// The authorization provider. + public McpAuthorizationFilter( + ILogger logger, + IServerAuthorizationProvider authProvider) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider)); + } + + /// + public async ValueTask InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next) + { + var httpContext = context.HttpContext; + + // Check if the Authorization header is present + if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader)) + { + // No Authorization header present, return 401 Unauthorized + var prm = _authProvider.GetProtectedResourceMetadata(); + var prmUrl = GetPrmUrl(httpContext, prm.Resource); + + _logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header"); + httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; + httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return Results.Empty; + } + + // Validate the token - ensuring authHeader is a non-null string + string authHeaderValue = authHeader.ToString(); + bool isValid = await _authProvider.ValidateTokenAsync(authHeaderValue); + if (!isValid) + { + // Invalid token, return 401 Unauthorized + var prm = _authProvider.GetProtectedResourceMetadata(); + var prmUrl = GetPrmUrl(httpContext, prm.Resource); + + _logger.LogDebug("Invalid authorization token, returning 401 Unauthorized"); + httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; + httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return Results.Empty; + } + + // Token is valid, proceed to the next filter + return await next(context); + } + + private static string GetPrmUrl(HttpContext context, string resourceUri) + { + // Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL + if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _)) + { + return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource"; + } + + // Build the URL from the current request + var request = context.Request; + var scheme = request.Scheme; + var host = request.Host.Value; + return $"{scheme}://{host}/.well-known/oauth-protected-resource"; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs new file mode 100644 index 00000000..02890f40 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs @@ -0,0 +1,53 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Factory for creating instances. +/// +internal class McpAuthorizationFilterFactory +{ + private readonly IServiceProvider _serviceProvider; + + /// + /// Initializes a new instance of the class. + /// + /// The service provider. + public McpAuthorizationFilterFactory(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); + } + + /// + /// Creates an endpoint filter delegate for authorization. + /// + /// The endpoint filter factory context. + /// The next filter delegate in the pipeline. + /// The filter delegate. + public EndpointFilterDelegate Create(EndpointFilterFactoryContext context, EndpointFilterDelegate next) + { + // This factory creates a filter that checks if the current endpoint is an SSE or message endpoint + // and applies authorization only to those endpoints + return async invocationContext => + { + var httpContext = invocationContext.HttpContext; + var path = httpContext.Request.Path.Value?.TrimEnd('/'); + + // Only apply authorization to /sse and /message endpoints + if (path != null && (path.EndsWith("/sse") || path.EndsWith("/message"))) + { + var authProvider = _serviceProvider.GetRequiredService(); + var logger = _serviceProvider.GetRequiredService>(); + + var filter = new McpAuthorizationFilter(logger, authProvider); + return await filter.InvokeAsync(invocationContext, next); + } + + // For all other endpoints, just invoke the next filter + return await next(invocationContext); + }; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 0eefa52f..56ae7b82 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -2,7 +2,9 @@ using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Messages; using System.Diagnostics.CodeAnalysis; @@ -20,12 +22,36 @@ public static class McpEndpointRouteBuilderExtensions ///
/// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. - /// Returns a builder for configuring additional endpoint conventions like authorization policies. + /// Returns a builder for configuring additional endpoint conventions like authorization policies. public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "") { var streamableHttpHandler = endpoints.ServiceProvider.GetService() ?? throw new InvalidOperationException("You must call WithHttpTransport(). Unable to find required services. Call builder.Services.AddMcpServer().WithHttpTransport() in application startup code."); + // Map the protected resource metadata endpoint if authorization is configured + var authProvider = endpoints.ServiceProvider.GetService(); + if (authProvider != null) + { + // Create and register the ProtectedResourceMetadataHandler if it's not already registered + ProtectedResourceMetadataHandler? prmHandler = null; + try + { + prmHandler = endpoints.ServiceProvider.GetService(); + } + catch + { + // Ignore - we'll create it below + } + + if (prmHandler == null) + { + var logger = endpoints.ServiceProvider.GetRequiredService>(); + prmHandler = new ProtectedResourceMetadataHandler(logger, authProvider); + } + + endpoints.MapGet("/.well-known/oauth-protected-resource", prmHandler.HandleAsync); + } + var mcpGroup = endpoints.MapGroup(pattern); var streamableHttpGroup = mcpGroup.MapGroup("") .WithDisplayName(b => $"MCP Streamable HTTP | {b.DisplayName}") @@ -44,6 +70,16 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo var sseGroup = mcpGroup.MapGroup("") .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); + // Apply authorization filter to SSE endpoints if authorization is configured + if (authProvider != null) + { + // Create the filter factory + var filterFactory = endpoints.ServiceProvider.GetRequiredService(); + + // Apply filter to SSE and message endpoints + sseGroup.AddEndpointFilterFactory(filterFactory.Create); + } + sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) diff --git a/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs b/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs new file mode 100644 index 00000000..ae9c87d9 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs @@ -0,0 +1,44 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Handler for the Protected Resource Metadata document endpoint. +/// +internal class ProtectedResourceMetadataHandler +{ + private readonly ILogger _logger; + private readonly IServerAuthorizationProvider _authProvider; + + /// + /// Initializes a new instance of the class. + /// + /// The logger. + /// The authorization provider. + public ProtectedResourceMetadataHandler( + ILogger logger, + IServerAuthorizationProvider authProvider) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider)); + } + + /// + /// Handles the request for the Protected Resource Metadata document. + /// + /// The HTTP context. + /// A task that represents the asynchronous operation. + public async Task HandleAsync(HttpContext context) + { + _logger.LogDebug("Serving Protected Resource Metadata document"); + context.Response.ContentType = "application/json"; + await JsonSerializer.SerializeAsync( + context.Response.Body, + _authProvider.GetProtectedResourceMetadata(), + McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); + } +} From fd60a1c159e6d14ab1ebeb70a6c276f1dd01ad2a Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 12:38:53 -0700 Subject: [PATCH 17/28] Trim implementation --- .../HttpMcpServerBuilderExtensions.cs | 1 - .../McpAuthorizationFilterFactory.cs | 53 ----------- .../McpEndpointAuthorizationExtensions.cs | 64 ++++++++++++++ .../McpEndpointAuthorizationFilter.cs | 84 ++++++++++++++++++ .../McpEndpointRouteBuilderExtensions.cs | 27 +++--- .../Transport/SseClientSessionTransport.cs | 88 +++++-------------- 6 files changed, 182 insertions(+), 135 deletions(-) delete mode 100644 src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs create mode 100644 src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationExtensions.cs create mode 100644 src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 9f329b6f..32fc5341 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -25,7 +25,6 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.TryAddSingleton(); builder.Services.TryAddSingleton(); - builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); if (configureOptions is not null) diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs deleted file mode 100644 index 02890f40..00000000 --- a/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilterFactory.cs +++ /dev/null @@ -1,53 +0,0 @@ -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol.Auth; - -namespace ModelContextProtocol.AspNetCore; - -/// -/// Factory for creating instances. -/// -internal class McpAuthorizationFilterFactory -{ - private readonly IServiceProvider _serviceProvider; - - /// - /// Initializes a new instance of the class. - /// - /// The service provider. - public McpAuthorizationFilterFactory(IServiceProvider serviceProvider) - { - _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); - } - - /// - /// Creates an endpoint filter delegate for authorization. - /// - /// The endpoint filter factory context. - /// The next filter delegate in the pipeline. - /// The filter delegate. - public EndpointFilterDelegate Create(EndpointFilterFactoryContext context, EndpointFilterDelegate next) - { - // This factory creates a filter that checks if the current endpoint is an SSE or message endpoint - // and applies authorization only to those endpoints - return async invocationContext => - { - var httpContext = invocationContext.HttpContext; - var path = httpContext.Request.Path.Value?.TrimEnd('/'); - - // Only apply authorization to /sse and /message endpoints - if (path != null && (path.EndsWith("/sse") || path.EndsWith("/message"))) - { - var authProvider = _serviceProvider.GetRequiredService(); - var logger = _serviceProvider.GetRequiredService>(); - - var filter = new McpAuthorizationFilter(logger, authProvider); - return await filter.InvokeAsync(invocationContext, next); - } - - // For all other endpoints, just invoke the next filter - return await next(invocationContext); - }; - } -} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationExtensions.cs new file mode 100644 index 00000000..5a8090fa --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationExtensions.cs @@ -0,0 +1,64 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Provides extension methods for adding MCP authorization to endpoints. +/// +public static class McpEndpointAuthorizationExtensions +{ + /// + /// Adds MCP authorization filter to an endpoint. + /// + /// The endpoint convention builder. + /// The authorization provider. + /// The service provider. + /// The builder for chaining. + public static IEndpointConventionBuilder AddMcpAuthorization( + this IEndpointConventionBuilder builder, + IServerAuthorizationProvider authProvider, + IServiceProvider serviceProvider) + { + if (authProvider == null) + { + return builder; // No authorization needed + } + + var logger = serviceProvider.GetRequiredService>(); + var filter = new McpEndpointAuthorizationFilter(logger, authProvider); + + return builder.AddEndpointFilter(filter); + } + + /// + /// Adds MCP authorization filter to multiple endpoints. + /// + /// The collection of endpoint convention builders. + /// The authorization provider. + /// The service provider. + /// The original collection for chaining. + public static IEnumerable AddMcpAuthorization( + this IEnumerable endpoints, + IServerAuthorizationProvider authProvider, + IServiceProvider serviceProvider) + { + if (authProvider == null) + { + return endpoints; // No authorization needed + } + + var logger = serviceProvider.GetRequiredService>(); + var filter = new McpEndpointAuthorizationFilter(logger, authProvider); + + foreach (var endpoint in endpoints) + { + endpoint.AddEndpointFilter(filter); + } + + return endpoints; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs new file mode 100644 index 00000000..f0ebe5df --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs @@ -0,0 +1,84 @@ +// filepath: c:\Users\ddelimarsky\source\csharp-sdk\src\ModelContextProtocol.AspNetCore\McpEndpointAuthorizationFilter.cs +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// An endpoint filter that handles authorization for MCP endpoints using the standard ASP.NET Core endpoint filter pattern. +/// +internal class McpEndpointAuthorizationFilter : IEndpointFilter +{ + private readonly ILogger _logger; + private readonly IServerAuthorizationProvider _authProvider; + + /// + /// Initializes a new instance of the class. + /// + /// The logger. + /// The authorization provider. + public McpEndpointAuthorizationFilter(ILogger logger, IServerAuthorizationProvider authProvider) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider)); + } + + /// + public async ValueTask InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next) + { + var httpContext = context.HttpContext; + + // Check if the Authorization header is present + if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader)) + { + // No Authorization header present, return 401 Unauthorized + var prm = _authProvider.GetProtectedResourceMetadata(); + var prmUrl = GetPrmUrl(httpContext, prm.Resource); + + _logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header"); + httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; + httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return Results.Empty; + } + + // Validate the token + string authHeaderValue = authHeader.ToString(); + bool isValid = await _authProvider.ValidateTokenAsync(authHeaderValue); + if (!isValid) + { + // Invalid token, return 401 Unauthorized + var prm = _authProvider.GetProtectedResourceMetadata(); + var prmUrl = GetPrmUrl(httpContext, prm.Resource); + + _logger.LogDebug("Invalid authorization token, returning 401 Unauthorized"); + httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; + httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return Results.Empty; + } + + // Token is valid, proceed to the next filter + return await next(context); + } + + /// + /// Builds the URL for the protected resource metadata endpoint. + /// + /// The HTTP context. + /// The resource URI from the protected resource metadata. + /// The full URL to the protected resource metadata endpoint. + private static string GetPrmUrl(HttpContext context, string resourceUri) + { + // Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL + if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _)) + { + return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource"; + } + + // Build the URL from the current request + var request = context.Request; + var scheme = request.Scheme; + var host = request.Host.Value; + return $"{scheme}://{host}/.well-known/oauth-protected-resource"; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 56ae7b82..51df44bd 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -66,26 +66,23 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); // Map legacy HTTP with SSE endpoints. - var sseHandler = endpoints.ServiceProvider.GetRequiredService(); - var sseGroup = mcpGroup.MapGroup("") + var sseHandler = endpoints.ServiceProvider.GetRequiredService(); var sseGroup = mcpGroup.MapGroup("") .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); - // Apply authorization filter to SSE endpoints if authorization is configured - if (authProvider != null) - { - // Create the filter factory - var filterFactory = endpoints.ServiceProvider.GetRequiredService(); - - // Apply filter to SSE and message endpoints - sseGroup.AddEndpointFilterFactory(filterFactory.Create); - } - - sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) + // Configure SSE endpoints + var sseEndpoint = sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); - sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) + + var messageEndpoint = sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) .WithMetadata(new AcceptsMetadata(["application/json"])) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + // Apply authorization filter directly to SSE endpoints if authorization is configured + if (authProvider != null) + { + // Apply authorization to both endpoints using the extension method + new[] { sseEndpoint, messageEndpoint }.AddMcpAuthorization(authProvider, endpoints.ServiceProvider); + } return mcpGroup; } -} +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index b9e8c6c0..c083764a 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -1,11 +1,9 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; using System.Diagnostics; -using System.Net; using System.Net.Http.Headers; using System.Net.ServerSentEvents; using System.Text; @@ -26,7 +24,6 @@ internal sealed partial class SseClientSessionTransport : TransportBase private Task? _receiveTask; private readonly ILogger _logger; private readonly TaskCompletionSource _connectionEstablished; - private readonly IAuthorizationHandler _authorizationHandler; /// /// SSE transport for client endpoints. Unlike stdio it does not launch a process, but connects to an existing server. @@ -48,18 +45,6 @@ public SseClientSessionTransport(SseClientTransportOptions transportOptions, Htt _connectionCts = new CancellationTokenSource(); _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _connectionEstablished = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - // Initialize the authorization handler - if (transportOptions.AuthorizationOptions?.AuthorizationHandler != null) - { - // Use explicitly provided handler - _authorizationHandler = transportOptions.AuthorizationOptions.AuthorizationHandler; - } - else - { - // Create default handler with auth options - _authorizationHandler = new DefaultAuthorizationHandler(loggerFactory, transportOptions.AuthorizationOptions); - } } /// @@ -89,48 +74,18 @@ public override async Task SendMessageAsync( if (_messageEndpoint == null) throw new InvalidOperationException("Transport not connected"); + using var content = new StringContent( + JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), + Encoding.UTF8, + "application/json" + ); + string messageId = "(no id)"; if (message is JsonRpcMessageWithId messageWithId) { messageId = messageWithId.Id.ToString(); } - - // Send the request, handling potential auth challenges - HttpResponseMessage? response = null; - bool authRetry = false; - - do - { - authRetry = false; - - // Create a new request for each attempt - using var currentRequest = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint); - currentRequest.Content = new StringContent( - JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), - Encoding.UTF8, - "application/json" - ); - - // Add authorization headers if needed - the handler will only add headers if auth is required - await _authorizationHandler.AuthenticateRequestAsync(currentRequest).ConfigureAwait(false); - - // Copy additional headers - CopyAdditionalHeaders(currentRequest.Headers); - - // Dispose previous response before making a new request - response?.Dispose(); - - response = await _httpClient.SendAsync(currentRequest, cancellationToken).ConfigureAwait(false); - - // Handle 401 Unauthorized response - this will only execute if the server requires auth - if (response.StatusCode == HttpStatusCode.Unauthorized) - { - // Try to handle the unauthorized response - authRetry = await _authorizationHandler.HandleUnauthorizedResponseAsync( - response, _messageEndpoint).ConfigureAwait(false); - } - } while (authRetry); using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint) { @@ -139,25 +94,26 @@ public override async Task SendMessageAsync( StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders); var response = await _httpClient.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); - var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); - // Check if the message was an initialize request - if (message is JsonRpcRequest request && request.Method == RequestMethods.Initialize) - { - // If the response is not a JSON-RPC response, it is an SSE message - if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) - { - LogAcceptedPost(Name, messageId); - // The response will arrive as an SSE message - } - else - { - JsonRpcResponse initializeResponse = JsonSerializer.Deserialize(responseContent, McpJsonUtilities.JsonContext.Default.JsonRpcResponse) ?? - throw new InvalidOperationException("Failed to initialize client"); + var responseContent = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); if (string.IsNullOrEmpty(responseContent) || responseContent.Equals("accepted", StringComparison.OrdinalIgnoreCase)) { - response.Dispose(); + LogAcceptedPost(Name, messageId); + } + else + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogRejectedPostSensitive(Name, messageId, responseContent); + } + else + { + LogRejectedPost(Name, messageId); + } + + throw new InvalidOperationException("Failed to send message"); } } From f699f774b013c2b115aab9445f29acef9f8fbc7e Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 12:42:01 -0700 Subject: [PATCH 18/28] Cleanup --- .../McpAuthorizationFilter.cs | 79 ------------------- .../McpEndpointAuthorizationFilter.cs | 1 - 2 files changed, 80 deletions(-) delete mode 100644 src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs deleted file mode 100644 index be5f0a72..00000000 --- a/src/ModelContextProtocol.AspNetCore/McpAuthorizationFilter.cs +++ /dev/null @@ -1,79 +0,0 @@ -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol.Auth; - -namespace ModelContextProtocol.AspNetCore; - -/// -/// An endpoint filter that handles authorization for MCP endpoints. -/// -internal class McpAuthorizationFilter : IEndpointFilter -{ - private readonly ILogger _logger; - private readonly IServerAuthorizationProvider _authProvider; - - /// - /// Initializes a new instance of the class. - /// - /// The logger. - /// The authorization provider. - public McpAuthorizationFilter( - ILogger logger, - IServerAuthorizationProvider authProvider) - { - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - _authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider)); - } - - /// - public async ValueTask InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next) - { - var httpContext = context.HttpContext; - - // Check if the Authorization header is present - if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader)) - { - // No Authorization header present, return 401 Unauthorized - var prm = _authProvider.GetProtectedResourceMetadata(); - var prmUrl = GetPrmUrl(httpContext, prm.Resource); - - _logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header"); - httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; - httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); - return Results.Empty; - } - - // Validate the token - ensuring authHeader is a non-null string - string authHeaderValue = authHeader.ToString(); - bool isValid = await _authProvider.ValidateTokenAsync(authHeaderValue); - if (!isValid) - { - // Invalid token, return 401 Unauthorized - var prm = _authProvider.GetProtectedResourceMetadata(); - var prmUrl = GetPrmUrl(httpContext, prm.Resource); - - _logger.LogDebug("Invalid authorization token, returning 401 Unauthorized"); - httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; - httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); - return Results.Empty; - } - - // Token is valid, proceed to the next filter - return await next(context); - } - - private static string GetPrmUrl(HttpContext context, string resourceUri) - { - // Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL - if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _)) - { - return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource"; - } - - // Build the URL from the current request - var request = context.Request; - var scheme = request.Scheme; - var host = request.Host.Value; - return $"{scheme}://{host}/.well-known/oauth-protected-resource"; - } -} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs index f0ebe5df..992b0745 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs @@ -1,4 +1,3 @@ -// filepath: c:\Users\ddelimarsky\source\csharp-sdk\src\ModelContextProtocol.AspNetCore\McpEndpointAuthorizationFilter.cs using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol.Auth; From dc8f3a1fa6552bd149fb246e3fc2a3dcc19b1773 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 16:30:10 -0700 Subject: [PATCH 19/28] Bit more cleanup here. --- .../Protocol/Auth/AuthorizationContext.cs | 2 +- .../Protocol/Auth/AuthorizationService.cs | 4 +- .../Auth/ProtectedResourceMetadata.cs | 19 +-- .../Protocol/Auth/ResourceMetadata.cs | 39 ------ .../Utils/Json/McpJsonUtilities.cs | 1 - .../Auth/ProtectedResourceMetadataTests.cs | 114 ++++++++++++++++++ 6 files changed, 120 insertions(+), 59 deletions(-) delete mode 100644 src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs create mode 100644 tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs index df4c8e59..a23da864 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs @@ -8,7 +8,7 @@ internal class AuthorizationContext /// /// Gets or sets the resource metadata. /// - public ResourceMetadata? ResourceMetadata { get; set; } + public ProtectedResourceMetadata? ResourceMetadata { get; set; } /// /// Gets or sets the authorization server metadata. diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs index 8db6028a..e3c506b6 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -26,7 +26,7 @@ public class AuthorizationService /// /// The HTTP response that contains the WWW-Authenticate header. /// A that represents the asynchronous operation. The task result contains the resource metadata if available. - public static async Task GetResourceMetadataFromResponseAsync(HttpResponseMessage response) + public static async Task GetResourceMetadataFromResponseAsync(HttpResponseMessage response) { if (response.StatusCode != HttpStatusCode.Unauthorized) { @@ -71,7 +71,7 @@ public class AuthorizationService // Read as string first, then deserialize using source-generated serializer using var reader = new StreamReader(contentStream); var json = await reader.ReadToEndAsync(); - return JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ResourceMetadata); + return JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ProtectedResourceMetadata); } catch (Exception) { diff --git a/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs index 7194b1b0..e80d58ad 100644 --- a/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs +++ b/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs @@ -17,13 +17,13 @@ public class ProtectedResourceMetadata /// Gets or sets the resource identifier URI. /// [JsonPropertyName("resource")] - public required string Resource { get; set; } + public required Uri Resource { get; set; } /// /// Gets or sets the authorization servers that can be used for authentication. /// [JsonPropertyName("authorization_servers")] - public required string[] AuthorizationServers { get; set; } + public required Uri[] AuthorizationServers { get; set; } /// /// Gets or sets the bearer token methods supported by the resource. @@ -41,18 +41,5 @@ public class ProtectedResourceMetadata /// Gets or sets the URL to the resource documentation. /// [JsonPropertyName("resource_documentation")] - public string? ResourceDocumentation { get; set; } - - /// - /// Converts this to the internal type. - /// - /// A instance with the same values as this instance. - internal ResourceMetadata ToResourceMetadata() => new() - { - Resource = Resource, - AuthorizationServers = AuthorizationServers, - BearerMethodsSupported = BearerMethodsSupported, - ScopesSupported = ScopesSupported, - ResourceDocumentation = ResourceDocumentation - }; + public Uri? ResourceDocumentation { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs deleted file mode 100644 index a57456c3..00000000 --- a/src/ModelContextProtocol/Protocol/Auth/ResourceMetadata.cs +++ /dev/null @@ -1,39 +0,0 @@ -using System.Text.Json.Serialization; - -namespace ModelContextProtocol.Protocol.Auth; - -/// -/// Represents the resource metadata from the WWW-Authenticate header in a 401 Unauthorized response. -/// -public class ResourceMetadata -{ - /// - /// Gets or sets the resource identifier URI. - /// - [JsonPropertyName("resource")] - public required string Resource { get; set; } - - /// - /// Gets or sets the authorization servers that can be used for authentication. - /// - [JsonPropertyName("authorization_servers")] - public required string[] AuthorizationServers { get; set; } - - /// - /// Gets or sets the bearer token methods supported by the resource. - /// - [JsonPropertyName("bearer_methods_supported")] - public string[]? BearerMethodsSupported { get; set; } - - /// - /// Gets or sets the scopes supported by the resource. - /// - [JsonPropertyName("scopes_supported")] - public string[]? ScopesSupported { get; set; } - - /// - /// Gets or sets the URL to the resource documentation. - /// - [JsonPropertyName("resource_documentation")] - public string? ResourceDocumentation { get; set; } -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 169b27e3..30226f01 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -124,7 +124,6 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(IReadOnlyDictionary))] // Authorization-related types - [JsonSerializable(typeof(Protocol.Auth.ResourceMetadata))] [JsonSerializable(typeof(Protocol.Auth.ProtectedResourceMetadata))] [JsonSerializable(typeof(Protocol.Auth.AuthorizationServerMetadata))] [JsonSerializable(typeof(Protocol.Auth.ClientMetadata))] diff --git a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs new file mode 100644 index 00000000..68df6940 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs @@ -0,0 +1,114 @@ +// filepath: c:\Users\ddelimarsky\source\csharp-sdk\tests\ModelContextProtocol.Tests\Protocol\Auth\ProtectedResourceMetadataTests.cs +using ModelContextProtocol.Protocol.Auth; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol.Auth; + +public class ProtectedResourceMetadataTests +{ + [Fact] + public void ProtectedResourceMetadata_JsonSerialization_Works() + { + // Arrange + var metadata = new ProtectedResourceMetadata + { + Resource = new Uri("http://localhost:7071"), + AuthorizationServers = [new Uri("https://login.microsoftonline.com/tenant/v2.0")], + BearerMethodsSupported = ["header"], + ScopesSupported = ["mcp.tools", "mcp.prompts"], + ResourceDocumentation = new Uri("https://example.com/docs") + }; + + // Act + var json = JsonSerializer.Serialize(metadata); + var deserialized = JsonSerializer.Deserialize(json); + + // Assert + Assert.NotNull(deserialized); + Assert.Equal("http://localhost:7071", deserialized.Resource.ToString()); + Assert.Equal("https://login.microsoftonline.com/tenant/v2.0", deserialized.AuthorizationServers[0].ToString()); + Assert.Equal("header", deserialized.BearerMethodsSupported![0]); + Assert.Equal(2, deserialized.ScopesSupported!.Length); + Assert.Contains("mcp.tools", deserialized.ScopesSupported!); + Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); + Assert.Equal("https://example.com/docs", deserialized.ResourceDocumentation!.ToString()); + } + + [Fact] + public void ProtectedResourceMetadata_JsonDeserialization_WorksWithStringProperties() + { + // Arrange + var json = @"{ + ""resource"": ""http://localhost:7071"", + ""authorization_servers"": [""https://login.microsoftonline.com/tenant/v2.0""], + ""bearer_methods_supported"": [""header""], + ""scopes_supported"": [""mcp.tools"", ""mcp.prompts""], + ""resource_documentation"": ""https://example.com/docs"" + }"; + + // Act + var deserialized = JsonSerializer.Deserialize(json); + + // Assert + Assert.NotNull(deserialized); + Assert.Equal("http://localhost:7071", deserialized.Resource.ToString()); + Assert.Equal("https://login.microsoftonline.com/tenant/v2.0", deserialized.AuthorizationServers[0].ToString()); + Assert.Equal("header", deserialized.BearerMethodsSupported![0]); + Assert.Equal(2, deserialized.ScopesSupported!.Length); + Assert.Contains("mcp.tools", deserialized.ScopesSupported!); + Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); + Assert.Equal("https://example.com/docs", deserialized.ResourceDocumentation!.ToString()); + } + + [Fact] + public void ResourceMetadata_JsonSerialization_Works() + { + // Arrange + var metadata = new ResourceMetadata + { + Resource = new Uri("http://localhost:7071"), + AuthorizationServers = [new Uri("https://login.microsoftonline.com/tenant/v2.0")], + BearerMethodsSupported = ["header"], + ScopesSupported = ["mcp.tools", "mcp.prompts"], + ResourceDocumentation = new Uri("https://example.com/docs") + }; + + // Act + var json = JsonSerializer.Serialize(metadata); + var deserialized = JsonSerializer.Deserialize(json); + + // Assert + Assert.NotNull(deserialized); + Assert.Equal("http://localhost:7071", deserialized.Resource.ToString()); + Assert.Equal("https://login.microsoftonline.com/tenant/v2.0", deserialized.AuthorizationServers[0].ToString()); + Assert.Equal("header", deserialized.BearerMethodsSupported![0]); + Assert.Equal(2, deserialized.ScopesSupported!.Length); + Assert.Contains("mcp.tools", deserialized.ScopesSupported!); + Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); + Assert.Equal("https://example.com/docs", deserialized.ResourceDocumentation!.ToString()); + } + + [Fact] + public void ToResourceMetadata_ConversionWorks() + { + // Arrange + var prm = new ProtectedResourceMetadata + { + Resource = new Uri("http://localhost:7071"), + AuthorizationServers = [new Uri("https://login.microsoftonline.com/tenant/v2.0")], + BearerMethodsSupported = ["header"], + ScopesSupported = ["mcp.tools", "mcp.prompts"], + ResourceDocumentation = new Uri("https://example.com/docs") + }; + + // Act + var resourceMetadata = prm.ToResourceMetadata(); + + // Assert + Assert.Equal(prm.Resource, resourceMetadata.Resource); + Assert.Equal(prm.AuthorizationServers, resourceMetadata.AuthorizationServers); + Assert.Equal(prm.BearerMethodsSupported, resourceMetadata.BearerMethodsSupported); + Assert.Equal(prm.ScopesSupported, resourceMetadata.ScopesSupported); + Assert.Equal(prm.ResourceDocumentation, resourceMetadata.ResourceDocumentation); + } +} From 7c2e1777a6a2ab382a3c2d411fb47df17f3728ca Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 16:30:44 -0700 Subject: [PATCH 20/28] Remove test that is no longer relevant --- .../Auth/ProtectedResourceMetadataTests.cs | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs index 68df6940..d711e255 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs @@ -87,28 +87,4 @@ public void ResourceMetadata_JsonSerialization_Works() Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); Assert.Equal("https://example.com/docs", deserialized.ResourceDocumentation!.ToString()); } - - [Fact] - public void ToResourceMetadata_ConversionWorks() - { - // Arrange - var prm = new ProtectedResourceMetadata - { - Resource = new Uri("http://localhost:7071"), - AuthorizationServers = [new Uri("https://login.microsoftonline.com/tenant/v2.0")], - BearerMethodsSupported = ["header"], - ScopesSupported = ["mcp.tools", "mcp.prompts"], - ResourceDocumentation = new Uri("https://example.com/docs") - }; - - // Act - var resourceMetadata = prm.ToResourceMetadata(); - - // Assert - Assert.Equal(prm.Resource, resourceMetadata.Resource); - Assert.Equal(prm.AuthorizationServers, resourceMetadata.AuthorizationServers); - Assert.Equal(prm.BearerMethodsSupported, resourceMetadata.BearerMethodsSupported); - Assert.Equal(prm.ScopesSupported, resourceMetadata.ScopesSupported); - Assert.Equal(prm.ResourceDocumentation, resourceMetadata.ResourceDocumentation); - } } From c88e473f3425a37c1747f96b7225279342e09fd5 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 17:24:51 -0700 Subject: [PATCH 21/28] Use URI properly --- samples/AuthorizationServerExample/Program.cs | 6 +-- .../McpEndpointAuthorizationFilter.cs | 23 ++------- .../Protocol/Auth/AuthorizationContext.cs | 32 ++++++------ .../Protocol/Auth/AuthorizationService.cs | 50 ++++++++++++------- .../Auth/DefaultAuthorizationHandler.cs | 38 ++++++-------- .../Auth/ProtectedResourceMetadataTests.cs | 4 +- 6 files changed, 70 insertions(+), 83 deletions(-) diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs index ac375fff..e81ba0b6 100644 --- a/samples/AuthorizationServerExample/Program.cs +++ b/samples/AuthorizationServerExample/Program.cs @@ -23,11 +23,11 @@ public static async Task Main(string[] args) // This is the information that will be provided to clients when they need to authenticate var prm = new ProtectedResourceMetadata { - Resource = "http://localhost:7071", // Changed from HTTPS to HTTP for local development - AuthorizationServers = ["https://login.microsoftonline.com/a2213e1c-e51e-4304-9a0d-effe57f31655/v2.0"], // Let's use a dummy Entra ID tenant here + Resource = new Uri("http://localhost:7071"), // Changed from HTTPS to HTTP for local development + AuthorizationServers = [ new Uri("https://login.microsoftonline.com/a2213e1c-e51e-4304-9a0d-effe57f31655/v2.0")], // Let's use a dummy Entra ID tenant here BearerMethodsSupported = ["header"], // We support the Authorization header ScopesSupported = ["mcp.tools", "mcp.prompts", "mcp.resources"], // Scopes supported by this resource - ResourceDocumentation = "https://example.com/docs/mcp-server-auth" // Optional documentation URL + ResourceDocumentation = new Uri("https://example.com/docs/mcp-server-auth") // Optional documentation URL }; // 2. Define a token validator function diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs index 992b0745..4f49face 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs @@ -21,9 +21,7 @@ public McpEndpointAuthorizationFilter(ILogger logger, IServerAuthorizationProvid { _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider)); - } - - /// + } /// public async ValueTask InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next) { var httpContext = context.HttpContext; @@ -58,26 +56,15 @@ public McpEndpointAuthorizationFilter(ILogger logger, IServerAuthorizationProvid // Token is valid, proceed to the next filter return await next(context); - } - - /// + }/// /// Builds the URL for the protected resource metadata endpoint. /// /// The HTTP context. /// The resource URI from the protected resource metadata. /// The full URL to the protected resource metadata endpoint. - private static string GetPrmUrl(HttpContext context, string resourceUri) + private static string GetPrmUrl(HttpContext context, Uri resourceUri) { - // Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL - if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _)) - { - return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource"; - } - - // Build the URL from the current request - var request = context.Request; - var scheme = request.Scheme; - var host = request.Host.Value; - return $"{scheme}://{host}/.well-known/oauth-protected-resource"; + // Create a new URI with the well-known path appended + return new Uri(resourceUri, ".well-known/oauth-protected-resource").ToString(); } } diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs index a23da864..f9c39811 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs @@ -62,36 +62,32 @@ internal class AuthorizationContext // Since HasValidToken checks that TokenResponse isn't null, we should never have null here, // but we'll add an explicit null check to satisfy the compiler return TokenResponse?.AccessToken; - } - + } + /// /// Gets a value indicating whether a refresh token is available for refreshing the access token. /// public bool CanRefreshToken => TokenResponse?.RefreshToken != null && ClientRegistration != null && AuthorizationServerMetadata != null; - + /// - /// Validates the URL of a resource against the resource URL from the metadata. + /// Validates a URI resource against the resource URI from the metadata. /// - /// The URL to validate. - /// True if the URLs match, otherwise false. - public bool ValidateResourceUrl(string resourceUrl) + /// The URI to validate. + /// True if the URIs match based on scheme and server components, otherwise false. + public bool ValidateResourceUrl(Uri resourceUri) { - if (ResourceMetadata == null || string.IsNullOrEmpty(ResourceMetadata.Resource)) + if (ResourceMetadata == null || ResourceMetadata.Resource == null || resourceUri == null) { return false; } - // Compare the host part (FQDN) rather than the full URL - if (Uri.TryCreate(resourceUrl, UriKind.Absolute, out Uri? resourceUri) && - Uri.TryCreate(ResourceMetadata.Resource, UriKind.Absolute, out Uri? metadataUri)) - { - // Compare only the host (domain name) - return string.Equals(resourceUri.Host, metadataUri.Host, StringComparison.OrdinalIgnoreCase); - } - - // If we can't parse both URLs, fall back to exact string comparison - return string.Equals(resourceUrl, ResourceMetadata.Resource, StringComparison.OrdinalIgnoreCase); + // Use Uri.Compare to properly compare the scheme and server components + // UriComponents.SchemeAndServer includes the scheme, host, port, and user info + // This handles edge cases like default ports, IPv6 addresses, etc. + return Uri.Compare(resourceUri, ResourceMetadata.Resource, + UriComponents.SchemeAndServer, UriFormat.UriEscaped, + StringComparison.OrdinalIgnoreCase) == 0; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs index e3c506b6..11753423 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -78,26 +78,21 @@ public class AuthorizationService // Failed to get resource metadata return null; } - } - - /// + } /// /// Discovers authorization server metadata from a well-known endpoint. /// - /// The base URL of the authorization server. + /// The base URL of the authorization server (as a string or Uri). /// A that represents the asynchronous operation. The task result contains the authorization server metadata. /// Thrown when both well-known endpoints return errors. - public static async Task DiscoverAuthorizationServerMetadataAsync(string authorizationServerUrl) + public static async Task DiscoverAuthorizationServerMetadataAsync(Uri authorizationServerUrl) { - Throw.IfNullOrWhiteSpace(authorizationServerUrl); - - // Remove trailing slash if present - if (authorizationServerUrl.EndsWith("/")) - { - authorizationServerUrl = authorizationServerUrl[..^1]; - } + Throw.IfNull(authorizationServerUrl); + // Create a base URI without any path, query, or fragment + var baseUri = new Uri($"{authorizationServerUrl.Scheme}://{authorizationServerUrl.Authority}"); + // Try OpenID Connect discovery endpoint - var openIdConfigUrl = $"{authorizationServerUrl}/.well-known/openid-configuration"; + var openIdConfigUrl = new Uri(baseUri, ".well-known/openid-configuration"); try { using var openIdResponse = await s_httpClient.GetAsync(openIdConfigUrl); @@ -124,8 +119,7 @@ public static async Task DiscoverAuthorizationServe } // Try OAuth 2.0 Authorization Server Metadata endpoint - var oauthConfigUrl = $"{authorizationServerUrl}/.well-known/oauth-authorization-server"; - try + var oauthConfigUrl = new Uri(baseUri, ".well-known/oauth-authorization-server"); try { using var oauthResponse = await s_httpClient.GetAsync(oauthConfigUrl); if (oauthResponse.IsSuccessStatusCode) @@ -148,13 +142,31 @@ public static async Task DiscoverAuthorizationServe catch (Exception ex) when (ex is not InvalidOperationException) { // Failed to get OAuth configuration - } - - throw new InvalidOperationException( - "Failed to discover authorization server metadata. " + + } throw new InvalidOperationException( + $"Failed to discover authorization server metadata for {authorizationServerUrl}. " + "Neither OpenID Connect nor OAuth 2.0 well-known endpoints are available."); } + /// + /// Discovers authorization server metadata from a well-known endpoint. + /// + /// The base URL of the authorization server as a string. + /// A that represents the asynchronous operation. The task result contains the authorization server metadata. + /// Thrown when both well-known endpoints return errors. + /// Thrown when the URL is invalid. + public static Task DiscoverAuthorizationServerMetadataAsync(string authorizationServerUrl) + { + Throw.IfNullOrWhiteSpace(authorizationServerUrl); + + // Create a Uri from the string and call the Uri-based method + if (Uri.TryCreate(authorizationServerUrl, UriKind.Absolute, out Uri? uri)) + { + return DiscoverAuthorizationServerMetadataAsync(uri); + } + + throw new ArgumentException("Invalid URI format", nameof(authorizationServerUrl)); + } + /// /// Registers a client with the authorization server. /// diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs index 326f82fb..7f3bbe78 100644 --- a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -94,36 +94,31 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp ResourceUri = serverUri.ToString() }; throw exception; - } - - // Store the resource metadata in the context before validating the resource URL + } // Store the resource metadata in the context before validating the resource URL authContext.Value.ResourceMetadata = resourceMetadata; // Validate that the resource matches the server FQDN - if (!authContext.Value.ValidateResourceUrl(serverUri.ToString())) + if (!authContext.Value.ValidateResourceUrl(serverUri)) { _logger.LogWarning("Resource URL mismatch: expected {Expected}, got {Actual}", serverUri, resourceMetadata.Resource); var exception = new AuthorizationException($"Resource URL mismatch: expected {serverUri}, got {resourceMetadata.Resource}"); - exception.ResourceUri = resourceMetadata.Resource; + exception.ResourceUri = resourceMetadata.Resource.ToString(); throw exception; } // Get the first authorization server from the metadata if (resourceMetadata.AuthorizationServers == null || resourceMetadata.AuthorizationServers.Length == 0) - { - _logger.LogWarning("No authorization servers found in resource metadata"); + { _logger.LogWarning("No authorization servers found in resource metadata"); var exception = new AuthorizationException("No authorization servers available"); - exception.ResourceUri = resourceMetadata.Resource; + exception.ResourceUri = resourceMetadata.Resource.ToString(); throw exception; } var authServerUrl = resourceMetadata.AuthorizationServers[0]; - _logger.LogDebug("Using authorization server: {AuthServerUrl}", authServerUrl); - - try + _logger.LogDebug("Using authorization server: {AuthServerUrl}", authServerUrl); try { // Discover authorization server metadata var authServerMetadata = await AuthorizationService.DiscoverAuthorizationServerMetadataAsync(authServerUrl); @@ -164,12 +159,11 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp else { _logger.LogWarning("Authorization server does not support dynamic client registration and no client ID was provided"); - - var exception = new AuthorizationException( + var exception = new AuthorizationException( "Authorization server does not support dynamic client registration and no client ID was provided. " + "Use McpAuthorizationOptions.ClientId to provide a pre-registered client ID."); - exception.ResourceUri = resourceMetadata.Resource; - exception.AuthorizationServerUri = authServerUrl; + exception.ResourceUri = resourceMetadata.Resource.ToString(); + exception.AuthorizationServerUri = authServerUrl.ToString(); throw exception; } @@ -177,12 +171,11 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp if (_authorizeCallback == null) { _logger.LogWarning("No authorization callback provided, can't proceed with OAuth flow"); - - var exception = new AuthorizationException( + var exception = new AuthorizationException( "Authentication is required but no authorization callback was provided. " + "Use McpAuthorizationOptions.AuthorizeCallback to provide a callback function."); - exception.ResourceUri = resourceMetadata.Resource; - exception.AuthorizationServerUri = authServerUrl; + exception.ResourceUri = resourceMetadata.Resource.ToString(); + exception.AuthorizationServerUri = authServerUrl.ToString(); throw exception; } @@ -229,12 +222,11 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp catch (Exception ex) when (ex is not AuthorizationException) { _logger.LogError(ex, "Failed to complete authorization flow"); - - var authException = new AuthorizationException( + var authException = new AuthorizationException( $"Failed to complete authorization flow: {ex.Message}", ex, McpErrorCode.InvalidRequest); - authException.ResourceUri = resourceMetadata.Resource; - authException.AuthorizationServerUri = authServerUrl; + authException.ResourceUri = resourceMetadata.Resource.ToString(); + authException.AuthorizationServerUri = authServerUrl.ToString(); throw authException; } diff --git a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs index d711e255..4b47d7d4 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs @@ -64,7 +64,7 @@ public void ProtectedResourceMetadata_JsonDeserialization_WorksWithStringPropert public void ResourceMetadata_JsonSerialization_Works() { // Arrange - var metadata = new ResourceMetadata + var metadata = new ProtectedResourceMetadata { Resource = new Uri("http://localhost:7071"), AuthorizationServers = [new Uri("https://login.microsoftonline.com/tenant/v2.0")], @@ -75,7 +75,7 @@ public void ResourceMetadata_JsonSerialization_Works() // Act var json = JsonSerializer.Serialize(metadata); - var deserialized = JsonSerializer.Deserialize(json); + var deserialized = JsonSerializer.Deserialize(json); // Assert Assert.NotNull(deserialized); From 4d379914ae8d84b3be18244ab09f3ca01b08e5c7 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 17:38:52 -0700 Subject: [PATCH 22/28] Functional cleanup --- .../AuthorizationMiddleware.cs | 19 +-------------- .../McpEndpointAuthorizationFilter.cs | 24 ++++--------------- .../ProtectedResourceMetadataHandler.cs | 18 ++++++++++++-- 3 files changed, 22 insertions(+), 39 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs index d11bdf11..c2d88958 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs @@ -58,25 +58,8 @@ await JsonSerializer.SerializeAsync( authProvider.GetProtectedResourceMetadata(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); return; - } - - // Proceed to the next middleware - authorization for SSE and message endpoints + } // Proceed to the next middleware - authorization for SSE and message endpoints // is now handled by endpoint filters await _next(context); } - - private static string GetPrmUrl(HttpContext context, string resourceUri) - { - // Use the actual resource URI from PRM if it's an absolute URL, otherwise build the URL - if (Uri.TryCreate(resourceUri, UriKind.Absolute, out _)) - { - return $"{resourceUri.TrimEnd('/')}/.well-known/oauth-protected-resource"; - } - - // Build the URL from the current request - var request = context.Request; - var scheme = request.Scheme; - var host = request.Host.Value; - return $"{scheme}://{host}/.well-known/oauth-protected-resource"; - } } \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs index 4f49face..0d6e47eb 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs @@ -28,10 +28,9 @@ public McpEndpointAuthorizationFilter(ILogger logger, IServerAuthorizationProvid // Check if the Authorization header is present if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader)) - { - // No Authorization header present, return 401 Unauthorized + { // No Authorization header present, return 401 Unauthorized var prm = _authProvider.GetProtectedResourceMetadata(); - var prmUrl = GetPrmUrl(httpContext, prm.Resource); + var prmUrl = ProtectedResourceMetadataHandler.GetProtectedResourceMetadataUrl(prm.Resource); _logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header"); httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; @@ -43,28 +42,15 @@ public McpEndpointAuthorizationFilter(ILogger logger, IServerAuthorizationProvid string authHeaderValue = authHeader.ToString(); bool isValid = await _authProvider.ValidateTokenAsync(authHeaderValue); if (!isValid) - { - // Invalid token, return 401 Unauthorized + { // Invalid token, return 401 Unauthorized var prm = _authProvider.GetProtectedResourceMetadata(); - var prmUrl = GetPrmUrl(httpContext, prm.Resource); + var prmUrl = ProtectedResourceMetadataHandler.GetProtectedResourceMetadataUrl(prm.Resource); _logger.LogDebug("Invalid authorization token, returning 401 Unauthorized"); httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); return Results.Empty; - } - - // Token is valid, proceed to the next filter + } // Token is valid, proceed to the next filter return await next(context); - }/// - /// Builds the URL for the protected resource metadata endpoint. - /// - /// The HTTP context. - /// The resource URI from the protected resource metadata. - /// The full URL to the protected resource metadata endpoint. - private static string GetPrmUrl(HttpContext context, Uri resourceUri) - { - // Create a new URI with the well-known path appended - return new Uri(resourceUri, ".well-known/oauth-protected-resource").ToString(); } } diff --git a/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs b/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs index ae9c87d9..5f0d2452 100644 --- a/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs @@ -33,12 +33,26 @@ public ProtectedResourceMetadataHandler( /// The HTTP context. /// A task that represents the asynchronous operation. public async Task HandleAsync(HttpContext context) - { - _logger.LogDebug("Serving Protected Resource Metadata document"); + { _logger.LogDebug("Serving Protected Resource Metadata document"); context.Response.ContentType = "application/json"; await JsonSerializer.SerializeAsync( context.Response.Body, _authProvider.GetProtectedResourceMetadata(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); + } /// + /// Builds the URL for the protected resource metadata endpoint. + /// + /// The resource URI from the protected resource metadata. + /// The full URL to the protected resource metadata endpoint. + /// Thrown when resourceUri is null. + public static string GetProtectedResourceMetadataUrl(Uri resourceUri) + { + if (resourceUri == null) + { + throw new ArgumentNullException(nameof(resourceUri), "Resource URI must be provided to build the protected resource metadata URL."); + } + + // Create a new URI with the well-known path appended + return new Uri(resourceUri, ".well-known/oauth-protected-resource").ToString(); } } From 084590dc02c328fbeb640e5223a3a30c42a17190 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 18:00:10 -0700 Subject: [PATCH 23/28] Update ProtectedResourceMetadataTests.cs --- .../Auth/ProtectedResourceMetadataTests.cs | 46 ++++--------------- 1 file changed, 10 insertions(+), 36 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs index 4b47d7d4..76a7d8aa 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs @@ -1,6 +1,8 @@ -// filepath: c:\Users\ddelimarsky\source\csharp-sdk\tests\ModelContextProtocol.Tests\Protocol\Auth\ProtectedResourceMetadataTests.cs using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Utils.Json; +using System; using System.Text.Json; +using Xunit; namespace ModelContextProtocol.Tests.Protocol.Auth; @@ -25,15 +27,15 @@ public void ProtectedResourceMetadata_JsonSerialization_Works() // Assert Assert.NotNull(deserialized); - Assert.Equal("http://localhost:7071", deserialized.Resource.ToString()); - Assert.Equal("https://login.microsoftonline.com/tenant/v2.0", deserialized.AuthorizationServers[0].ToString()); + Assert.Equal(metadata.Resource, deserialized.Resource); + Assert.Equal(metadata.AuthorizationServers[0], deserialized.AuthorizationServers[0]); Assert.Equal("header", deserialized.BearerMethodsSupported![0]); Assert.Equal(2, deserialized.ScopesSupported!.Length); Assert.Contains("mcp.tools", deserialized.ScopesSupported!); Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); - Assert.Equal("https://example.com/docs", deserialized.ResourceDocumentation!.ToString()); + Assert.Equal(metadata.ResourceDocumentation, deserialized.ResourceDocumentation); } - + [Fact] public void ProtectedResourceMetadata_JsonDeserialization_WorksWithStringProperties() { @@ -51,40 +53,12 @@ public void ProtectedResourceMetadata_JsonDeserialization_WorksWithStringPropert // Assert Assert.NotNull(deserialized); - Assert.Equal("http://localhost:7071", deserialized.Resource.ToString()); - Assert.Equal("https://login.microsoftonline.com/tenant/v2.0", deserialized.AuthorizationServers[0].ToString()); - Assert.Equal("header", deserialized.BearerMethodsSupported![0]); - Assert.Equal(2, deserialized.ScopesSupported!.Length); - Assert.Contains("mcp.tools", deserialized.ScopesSupported!); - Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); - Assert.Equal("https://example.com/docs", deserialized.ResourceDocumentation!.ToString()); - } - - [Fact] - public void ResourceMetadata_JsonSerialization_Works() - { - // Arrange - var metadata = new ProtectedResourceMetadata - { - Resource = new Uri("http://localhost:7071"), - AuthorizationServers = [new Uri("https://login.microsoftonline.com/tenant/v2.0")], - BearerMethodsSupported = ["header"], - ScopesSupported = ["mcp.tools", "mcp.prompts"], - ResourceDocumentation = new Uri("https://example.com/docs") - }; - - // Act - var json = JsonSerializer.Serialize(metadata); - var deserialized = JsonSerializer.Deserialize(json); - - // Assert - Assert.NotNull(deserialized); - Assert.Equal("http://localhost:7071", deserialized.Resource.ToString()); - Assert.Equal("https://login.microsoftonline.com/tenant/v2.0", deserialized.AuthorizationServers[0].ToString()); + Assert.Equal(new Uri("http://localhost:7071"), deserialized.Resource); + Assert.Equal(new Uri("https://login.microsoftonline.com/tenant/v2.0"), deserialized.AuthorizationServers[0]); Assert.Equal("header", deserialized.BearerMethodsSupported![0]); Assert.Equal(2, deserialized.ScopesSupported!.Length); Assert.Contains("mcp.tools", deserialized.ScopesSupported!); Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); - Assert.Equal("https://example.com/docs", deserialized.ResourceDocumentation!.ToString()); + Assert.Equal(new Uri("https://example.com/docs"), deserialized.ResourceDocumentation); } } From f9f7c9da3a773b10f4c899c566de7e8e9c29e652 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 18:00:29 -0700 Subject: [PATCH 24/28] Update ProtectedResourceMetadataTests.cs --- .../Protocol/Auth/ProtectedResourceMetadataTests.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs index 76a7d8aa..e4af9c91 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs @@ -1,8 +1,5 @@ using ModelContextProtocol.Protocol.Auth; -using ModelContextProtocol.Utils.Json; -using System; using System.Text.Json; -using Xunit; namespace ModelContextProtocol.Tests.Protocol.Auth; From 3676c0e7c44cefbbb6fe0075d4e8a072b5ed079e Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 18:02:24 -0700 Subject: [PATCH 25/28] Update for consistency --- .../Protocol/Auth/AuthorizationContext.cs | 4 ++-- .../Protocol/Auth/AuthorizationService.cs | 12 ++++++------ ...RegistrationResponse.cs => ClientRegistration.cs} | 2 +- .../Protocol/Auth/DefaultAuthorizationHandler.cs | 2 +- .../Protocol/Auth/{TokenResponse.cs => Token.cs} | 2 +- .../Utils/Json/McpJsonUtilities.cs | 4 ++-- 6 files changed, 13 insertions(+), 13 deletions(-) rename src/ModelContextProtocol/Protocol/Auth/{ClientRegistrationResponse.cs => ClientRegistration.cs} (96%) rename src/ModelContextProtocol/Protocol/Auth/{TokenResponse.cs => Token.cs} (97%) diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs index f9c39811..c8e36c2f 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs @@ -18,12 +18,12 @@ internal class AuthorizationContext /// /// Gets or sets the client registration response. /// - public ClientRegistrationResponse? ClientRegistration { get; set; } + public ClientRegistration? ClientRegistration { get; set; } /// /// Gets or sets the token response. /// - public TokenResponse? TokenResponse { get; set; } + public Token? TokenResponse { get; set; } /// /// Gets or sets the code verifier for PKCE. diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs index 11753423..16ff917c 100644 --- a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -174,7 +174,7 @@ public static Task DiscoverAuthorizationServerMetad /// The client metadata for registration. /// A that represents the asynchronous operation. The task result contains the client registration response. /// Thrown when the authorization server does not support dynamic client registration. - public static async Task RegisterClientAsync( + public static async Task RegisterClientAsync( AuthorizationServerMetadata metadata, ClientMetadata clientMetadata) { @@ -198,7 +198,7 @@ public static async Task RegisterClientAsync( var contentStream = await response.Content.ReadAsStreamAsync(); using var reader = new StreamReader(contentStream); var json = await reader.ReadToEndAsync(); - var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ClientRegistrationResponse); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ClientRegistration); if (result == null) { @@ -288,7 +288,7 @@ public static string CreateAuthorizationUrl( /// The authorization code received from the authorization server. /// The code verifier for PKCE. /// A that represents the asynchronous operation. The task result contains the token response. - public static async Task ExchangeCodeForTokensAsync( + public static async Task ExchangeCodeForTokensAsync( AuthorizationServerMetadata metadata, string clientId, string? clientSecret, @@ -330,7 +330,7 @@ public static async Task ExchangeCodeForTokensAsync( var contentStream = await response.Content.ReadAsStreamAsync(); using var reader = new StreamReader(contentStream); var json = await reader.ReadToEndAsync(); - var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.TokenResponse); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.Token); if (result == null) { @@ -348,7 +348,7 @@ public static async Task ExchangeCodeForTokensAsync( /// The client secret. /// The refresh token. /// A that represents the asynchronous operation. The task result contains the token response. - public static async Task RefreshTokenAsync( + public static async Task RefreshTokenAsync( AuthorizationServerMetadata metadata, string clientId, string? clientSecret, @@ -384,7 +384,7 @@ public static async Task RefreshTokenAsync( var contentStream = await response.Content.ReadAsStreamAsync(); using var reader = new StreamReader(contentStream); var json = await reader.ReadToEndAsync(); - var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.TokenResponse); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.Token); if (result == null) { diff --git a/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs b/src/ModelContextProtocol/Protocol/Auth/ClientRegistration.cs similarity index 96% rename from src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs rename to src/ModelContextProtocol/Protocol/Auth/ClientRegistration.cs index d7042b3a..d4e66b98 100644 --- a/src/ModelContextProtocol/Protocol/Auth/ClientRegistrationResponse.cs +++ b/src/ModelContextProtocol/Protocol/Auth/ClientRegistration.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Represents the OAuth 2.0 client registration response as defined in RFC 7591. /// -public class ClientRegistrationResponse +public class ClientRegistration { /// /// Gets or sets the OAuth 2.0 client identifier string. diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs index 7f3bbe78..c10286f3 100644 --- a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -140,7 +140,7 @@ public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage resp _logger.LogDebug("Using pre-configured client ID: {ClientId}", _clientId); // Create a client registration response to store in the context - var clientRegistration = new ClientRegistrationResponse + var clientRegistration = new ClientRegistration { ClientId = _clientId!, // Using null-forgiving operator since we've already checked it's not null ClientSecret = _clientSecret, diff --git a/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs b/src/ModelContextProtocol/Protocol/Auth/Token.cs similarity index 97% rename from src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs rename to src/ModelContextProtocol/Protocol/Auth/Token.cs index d6b33489..f9559068 100644 --- a/src/ModelContextProtocol/Protocol/Auth/TokenResponse.cs +++ b/src/ModelContextProtocol/Protocol/Auth/Token.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.Protocol.Auth; /// /// Represents the OAuth 2.0 token response as defined in RFC 6749. /// -public class TokenResponse +public class Token { /// /// Gets or sets the access token issued by the authorization server. diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index 30226f01..7b33940b 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -127,8 +127,8 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(Protocol.Auth.ProtectedResourceMetadata))] [JsonSerializable(typeof(Protocol.Auth.AuthorizationServerMetadata))] [JsonSerializable(typeof(Protocol.Auth.ClientMetadata))] - [JsonSerializable(typeof(Protocol.Auth.ClientRegistrationResponse))] - [JsonSerializable(typeof(Protocol.Auth.TokenResponse))] + [JsonSerializable(typeof(Protocol.Auth.ClientRegistration))] + [JsonSerializable(typeof(Protocol.Auth.Token))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; From 950a3c450f8fb268c8fd28c5b1eb89859fd91efb Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 18:04:45 -0700 Subject: [PATCH 26/28] Cleanup --- samples/AuthorizationServerExample/Program.cs | 2 +- ...ationProvider.cs => BasicServerAuthorizationProvider.cs} | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename src/ModelContextProtocol/Server/Auth/{SimpleServerAuthorizationProvider.cs => BasicServerAuthorizationProvider.cs} (88%) diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs index e81ba0b6..921f8355 100644 --- a/samples/AuthorizationServerExample/Program.cs +++ b/samples/AuthorizationServerExample/Program.cs @@ -40,7 +40,7 @@ async Task ValidateToken(string token) } // 3. Create an authorization provider with the PRM and token validator - var authProvider = new SimpleServerAuthorizationProvider(prm, ValidateToken); + var authProvider = new BasicServerAuthorizationProvider(prm, ValidateToken); // 4. Configure the MCP server with authorization builder.Services.AddMcpServer(options => diff --git a/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs b/src/ModelContextProtocol/Server/Auth/BasicServerAuthorizationProvider.cs similarity index 88% rename from src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs rename to src/ModelContextProtocol/Server/Auth/BasicServerAuthorizationProvider.cs index fdd9b925..48950ea1 100644 --- a/src/ModelContextProtocol/Server/Auth/SimpleServerAuthorizationProvider.cs +++ b/src/ModelContextProtocol/Server/Auth/BasicServerAuthorizationProvider.cs @@ -3,7 +3,7 @@ namespace ModelContextProtocol.Server.Auth; /// -/// A simple implementation of that validates bearer tokens. +/// A basic implementation of . /// /// /// This implementation is intended as a starting point for server developers. In production environments, @@ -11,12 +11,12 @@ namespace ModelContextProtocol.Server.Auth; /// authentication system (e.g., OAuth 2.0 server, identity provider, etc.) /// /// -/// Initializes a new instance of the class +/// Initializes a new instance of the class /// with the specified resource metadata and token validator. /// /// The protected resource metadata. /// A function that validates access tokens. If not provided, a function that always returns true will be used. -public class SimpleServerAuthorizationProvider( +public class BasicServerAuthorizationProvider( ProtectedResourceMetadata resourceMetadata, Func>? tokenValidator = null) : IServerAuthorizationProvider { From e8b3e0d916f82c4ff10caafbeb31c91571eb2ff9 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 19:37:39 -0700 Subject: [PATCH 27/28] Cleanup --- samples/AuthorizationServerExample/Program.cs | 21 +++++------ .../AuthorizationExtensions.cs | 3 ++ .../AuthorizationMiddleware.cs | 7 ++-- .../HttpMcpServerBuilderExtensions.cs | 6 ++- .../McpAuthorizationStartupFilter.cs | 37 +++++++++++++++++++ .../McpEndpointRouteBuilderExtensions.cs | 9 +++-- .../McpServerAuthorizationExtensions.cs | 20 +++++++--- 7 files changed, 77 insertions(+), 26 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/McpAuthorizationStartupFilter.cs diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs index 921f8355..c58e8b03 100644 --- a/samples/AuthorizationServerExample/Program.cs +++ b/samples/AuthorizationServerExample/Program.cs @@ -1,5 +1,5 @@ using ModelContextProtocol; -using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server.Auth; @@ -40,9 +40,12 @@ async Task ValidateToken(string token) } // 3. Create an authorization provider with the PRM and token validator - var authProvider = new BasicServerAuthorizationProvider(prm, ValidateToken); - - // 4. Configure the MCP server with authorization + var authProvider = new BasicServerAuthorizationProvider(prm, ValidateToken); // 4. Configure the MCP server with authorization + // WithAuthorization will automatically configure: + // - Authorization provider registration + // - Protected resource metadata endpoint (/.well-known/oauth-protected-resource) + // - Token validation middleware + // - Authorization for all MCP endpoints builder.Services.AddMcpServer(options => { options.ServerInstructions = "This is an MCP server with OAuth authorization enabled."; @@ -106,14 +109,8 @@ async Task ValidateToken(string token) var app = builder.Build(); - // 5. Enable authorization middleware (this must be before MapMcp) - // This middleware does several things: - // - Serves the PRM document at /.well-known/oauth-protected-resource - // - Checks Authorization header on requests - // - Returns 401 + WWW-Authenticate when authorization is missing or invalid - app.UseMcpAuthorization(); - - // 6. Map MCP endpoints + // 5. Map MCP endpoints + // Note: Authorization is now handled automatically by WithAuthorization() app.MapMcp(); // Configure the server URL diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs index 16f93e98..7fd62a81 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs @@ -10,6 +10,9 @@ public static class AuthorizationExtensions /// /// Adds MCP authorization middleware to the specified , which enables /// OAuth 2.0 authorization for MCP servers. + /// + /// Note: This method is called automatically when using WithAuthorization(), so you typically + /// don't need to call it directly. It's available for advanced scenarios where more control is needed. /// /// The to add the middleware to. /// A reference to this instance after the operation has completed. diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs index c2d88958..45fd9347 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs @@ -25,9 +25,7 @@ public AuthorizationMiddleware(RequestDelegate next, ILogger + } /// /// Processes a request. /// /// The HTTP context. @@ -42,6 +40,7 @@ public async Task InvokeAsync( // Check if authorization is configured if (authProvider == null) { + _logger.LogDebug("Authorization is not configured, skipping authorization middleware"); // Authorization is not configured, proceed to the next middleware await _next(context); return; @@ -54,7 +53,7 @@ public async Task InvokeAsync( _logger.LogDebug("Serving Protected Resource Metadata document"); context.Response.ContentType = "application/json"; await JsonSerializer.SerializeAsync( - context.Response.Body, + context.Response.Body, authProvider.GetProtectedResourceMetadata(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); return; diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 32fc5341..ec74cdf0 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection.Extensions; using ModelContextProtocol.AspNetCore; using ModelContextProtocol.Server; @@ -27,6 +28,9 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); + // Add our auto-registration for the authorization middleware + builder.Services.AddTransient(); + if (configureOptions is not null) { builder.Services.Configure(configureOptions); diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationStartupFilter.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationStartupFilter.cs new file mode 100644 index 00000000..64b70280 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpAuthorizationStartupFilter.cs @@ -0,0 +1,37 @@ +// filepath: c:\Users\ddelimarsky\source\csharp-sdk\src\ModelContextProtocol.AspNetCore\McpAuthorizationStartupFilter.cs +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Auth; +using System; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// StartupFilter that automatically adds the MCP authorization middleware when authorization is configured. +/// +internal class McpAuthorizationStartupFilter : IStartupFilter +{ + /// + /// Configures the middleware pipeline to include MCP authorization middleware when needed. + /// + /// The next configurator in the chain. + /// A new pipeline configuration action. + public Action Configure(Action next) + { + return app => + { + // Check if authorization provider is registered + bool hasAuthProvider = app.ApplicationServices.GetService() != null; + + // If authorization is configured, add the middleware + if (hasAuthProvider) + { + app.UseMcpAuthorization(); + } + + // Continue with the rest of the pipeline configuration + next(app); + }; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 51df44bd..a3cc9fb5 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -75,13 +75,14 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo var messageEndpoint = sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) .WithMetadata(new AcceptsMetadata(["application/json"])) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); - - // Apply authorization filter directly to SSE endpoints if authorization is configured + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); // Apply authorization filter directly to the endpoints if authorization is configured if (authProvider != null) { - // Apply authorization to both endpoints using the extension method + // Apply authorization to both SSE endpoints using the extension method new[] { sseEndpoint, messageEndpoint }.AddMcpAuthorization(authProvider, endpoints.ServiceProvider); + + // Apply authorization to the Streamable HTTP endpoints using the extension method + streamableHttpGroup.AddMcpAuthorization(authProvider, endpoints.ServiceProvider); } return mcpGroup; } diff --git a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs index 7592b5e9..cb9a7435 100644 --- a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs @@ -1,29 +1,39 @@ +using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using ModelContextProtocol.Utils; -namespace Microsoft.Extensions.DependencyInjection; +namespace ModelContextProtocol.Configuration; /// /// Extension methods for configuring authorization in MCP servers. /// public static class McpServerAuthorizationExtensions -{ +{ /// - /// Adds authorization support to the MCP server. + /// Adds authorization support to the MCP server and automatically configures the required middleware. + /// You don't need to call UseMcpAuthorization() separately - it will be handled automatically. /// /// The to configure. /// The authorization provider that will validate tokens and provide metadata. /// The so that additional calls can be chained. /// or is . + /// + /// This method automatically configures all the necessary components for authorization: + /// 1. Registers the authorization provider in the DI container + /// 2. Configures authorization middleware to serve the protected resource metadata + /// 3. Adds authorization to MCP endpoints when they are mapped + /// + /// You no longer need to call app.UseMcpAuthorization() explicitly. + /// public static IMcpServerBuilder WithAuthorization( this IMcpServerBuilder builder, IServerAuthorizationProvider authorizationProvider) { Throw.IfNull(builder); - Throw.IfNull(authorizationProvider); - + Throw.IfNull(authorizationProvider); + // Register the authorization provider in the DI container builder.Services.AddSingleton(authorizationProvider); From e80bad2db398999e5d6e4bc86aed6ab33c0759b3 Mon Sep 17 00:00:00 2001 From: "den (work)" <53200638+localden@users.noreply.github.com> Date: Thu, 1 May 2025 19:44:37 -0700 Subject: [PATCH 28/28] Update Program.cs --- samples/AuthorizationExample/Program.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs index 42b1c83c..a7c34950 100644 --- a/samples/AuthorizationExample/Program.cs +++ b/samples/AuthorizationExample/Program.cs @@ -36,7 +36,8 @@ public static async Task Main(string[] args) RedirectUris = new[] { $"http://{hostname}:{port}{callbackPath}" - }, // Configure the authorize callback with the same hostname, port, and path + }, + // Configure the authorize callback with the same hostname, port, and path AuthorizeCallback = AuthorizationService.CreateHttpListenerAuthorizeCallback( openBrowser: async (url) => {