diff --git a/.gitignore b/.gitignore index 171615f9..a9ca39b1 100644 --- a/.gitignore +++ b/.gitignore @@ -80,4 +80,7 @@ docs/api # Rider .idea/ -.idea_modules/ \ No newline at end of file +.idea_modules/ + +# Specs +.specs/ \ No newline at end of file diff --git a/ModelContextProtocol.sln b/ModelContextProtocol.sln index 0e4fd721..c033ea40 100644 --- a/ModelContextProtocol.sln +++ b/ModelContextProtocol.sln @@ -56,6 +56,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol.AspNet EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ModelContextProtocol.AspNetCore.Tests", "tests\ModelContextProtocol.AspNetCore.Tests\ModelContextProtocol.AspNetCore.Tests.csproj", "{85557BA6-3D29-4C95-A646-2A972B1C2F25}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AuthorizationExample", "samples\AuthorizationExample\AuthorizationExample.csproj", "{C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "AuthorizationServerExample", "samples\AuthorizationServerExample\AuthorizationServerExample.csproj", "{05C500AF-9CF6-C2E7-2782-95271975A5DE}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -110,6 +114,14 @@ Global {85557BA6-3D29-4C95-A646-2A972B1C2F25}.Debug|Any CPU.Build.0 = Debug|Any CPU {85557BA6-3D29-4C95-A646-2A972B1C2F25}.Release|Any CPU.ActiveCfg = Release|Any CPU {85557BA6-3D29-4C95-A646-2A972B1C2F25}.Release|Any CPU.Build.0 = Release|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C}.Release|Any CPU.Build.0 = Release|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {05C500AF-9CF6-C2E7-2782-95271975A5DE}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -128,6 +140,8 @@ Global {17B8453F-AB72-99C5-E5EA-D0B065A6AE65} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} {37B6A5E0-9995-497D-8B43-3BC6870CC716} = {A2F1F52A-9107-4BF8-8C3F-2F6670E7D0AD} {85557BA6-3D29-4C95-A646-2A972B1C2F25} = {2A77AF5C-138A-4EBB-9A13-9205DCD67928} + {C2E8E0D9-5F7B-38D8-3D5D-041471BD350C} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} + {05C500AF-9CF6-C2E7-2782-95271975A5DE} = {02EA681E-C7D8-13C7-8484-4AC65E1B71E8} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {384A3888-751F-4D75-9AE5-587330582D89} diff --git a/samples/AuthorizationExample/AuthorizationExample.csproj b/samples/AuthorizationExample/AuthorizationExample.csproj new file mode 100644 index 00000000..60091057 --- /dev/null +++ b/samples/AuthorizationExample/AuthorizationExample.csproj @@ -0,0 +1,14 @@ + + + + Exe + net8.0 + enable + enable + + + + + + + \ No newline at end of file diff --git a/samples/AuthorizationExample/Program.cs b/samples/AuthorizationExample/Program.cs new file mode 100644 index 00000000..a7c34950 --- /dev/null +++ b/samples/AuthorizationExample/Program.cs @@ -0,0 +1,92 @@ +using System.Diagnostics; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Protocol.Transport; + +namespace AuthorizationExample; + +/// +/// Example demonstrating how to use the MCP C# SDK with OAuth authorization. +/// +public class Program +{ + public static async Task Main(string[] args) + { + // Define the MCP server endpoint that requires OAuth authentication + var serverEndpoint = new Uri("http://localhost:7071/sse"); + + // Configuration values for OAuth redirect + string hostname = "localhost"; + int port = 13261; + string callbackPath = "/oauth/callback/"; + + // Set up the SSE transport with authorization support + var transportOptions = new SseClientTransportOptions + { + Endpoint = serverEndpoint, + AuthorizationOptions = new AuthorizationOptions + { + // Pre-registered client credentials (if applicable) + ClientId = "04f79824-ab56-4511-a7cb-d7deaea92dc0", + + // Setting some pre-defined scopes the client requests. + Scopes = ["User.Read"], + + // Specify the exact same redirect URIs that are registered with the OAuth server + RedirectUris = new[] + { + $"http://{hostname}:{port}{callbackPath}" + }, + // Configure the authorize callback with the same hostname, port, and path + AuthorizeCallback = AuthorizationService.CreateHttpListenerAuthorizeCallback( + openBrowser: async (url) => + { + Console.WriteLine($"Opening browser to authorize at: {url}"); + Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }); + }, + hostname: hostname, + listenPort: port, + redirectPath: callbackPath + ) + } + }; + + Console.WriteLine("Connecting to MCP server..."); + + try + { + // Create the client with authorization-enabled transport + var transport = new SseClientTransport(transportOptions); + var client = await McpClientFactory.CreateAsync(transport); + + Console.WriteLine("Successfully connected and authorized!"); + + // Print the list of tools available from the server. + Console.WriteLine("\nAvailable tools:"); + foreach (var tool in await client.ListToolsAsync()) + { + Console.WriteLine($" - {tool.Name}: {tool.Description}"); + } + + // Execute a tool (this would normally be driven by LLM tool invocations). + Console.WriteLine("\nCalling 'echo' tool..."); + var result = await client.CallToolAsync( + "echo", + new Dictionary() { ["message"] = "Hello MCP!" }, + cancellationToken: CancellationToken.None); + + // echo always returns one and only one text content object + Console.WriteLine($"Tool response: {result.Content.First(c => c.Type == "text").Text}"); + } + catch (Exception ex) + { + Console.WriteLine($"Error: {ex.Message}"); + if (ex.InnerException != null) + { + Console.WriteLine($"Inner Error: {ex.InnerException.Message}"); + } + // Print the stack trace for debugging + Console.WriteLine($"Stack Trace:\n{ex.StackTrace}"); + } + } +} \ No newline at end of file diff --git a/samples/AuthorizationServerExample/AuthorizationServerExample.csproj b/samples/AuthorizationServerExample/AuthorizationServerExample.csproj new file mode 100644 index 00000000..fad460dc --- /dev/null +++ b/samples/AuthorizationServerExample/AuthorizationServerExample.csproj @@ -0,0 +1,15 @@ + + + + Exe + net8.0 + enable + enable + + + + + + + + \ No newline at end of file diff --git a/samples/AuthorizationServerExample/Program.cs b/samples/AuthorizationServerExample/Program.cs new file mode 100644 index 00000000..c58e8b03 --- /dev/null +++ b/samples/AuthorizationServerExample/Program.cs @@ -0,0 +1,132 @@ +using ModelContextProtocol; +using ModelContextProtocol.Configuration; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server.Auth; + +namespace AuthorizationServerExample; + +/// +/// Example demonstrating how to implement authorization in an MCP server. +/// +public class Program +{ + public static async Task Main(string[] args) + { + Console.WriteLine("=== MCP Server with Authorization Support ==="); + Console.WriteLine("This example demonstrates how to implement OAuth authorization in an MCP server."); + Console.WriteLine(); + + var builder = WebApplication.CreateBuilder(args); + + // 1. Define the Protected Resource Metadata for the server + // This is the information that will be provided to clients when they need to authenticate + var prm = new ProtectedResourceMetadata + { + Resource = new Uri("http://localhost:7071"), // Changed from HTTPS to HTTP for local development + AuthorizationServers = [ new Uri("https://login.microsoftonline.com/a2213e1c-e51e-4304-9a0d-effe57f31655/v2.0")], // Let's use a dummy Entra ID tenant here + BearerMethodsSupported = ["header"], // We support the Authorization header + ScopesSupported = ["mcp.tools", "mcp.prompts", "mcp.resources"], // Scopes supported by this resource + ResourceDocumentation = new Uri("https://example.com/docs/mcp-server-auth") // Optional documentation URL + }; + + // 2. Define a token validator function + // This function receives the token from the Authorization header and should validate it + // In a real application, this would verify the token with your identity provider + async Task ValidateToken(string token) + { + // For demo purposes, we'll accept any token. + return true; + } + + // 3. Create an authorization provider with the PRM and token validator + var authProvider = new BasicServerAuthorizationProvider(prm, ValidateToken); // 4. Configure the MCP server with authorization + // WithAuthorization will automatically configure: + // - Authorization provider registration + // - Protected resource metadata endpoint (/.well-known/oauth-protected-resource) + // - Token validation middleware + // - Authorization for all MCP endpoints + builder.Services.AddMcpServer(options => + { + options.ServerInstructions = "This is an MCP server with OAuth authorization enabled."; + + // Configure regular server capabilities like tools, prompts, resources + options.Capabilities = new() + { + Tools = new() + { + // Simple Echo tool + + CallToolHandler = (request, cancellationToken) => + { + if (request.Params?.Name == "echo") + { + if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) + { + throw new McpException("Missing required argument 'message'"); + } + + return new ValueTask(new CallToolResponse() + { + Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] + }); + } + + // Protected tool that requires authorization + if (request.Params?.Name == "protected-data") + { + // This tool will only be accessible to authenticated clients + return new ValueTask(new CallToolResponse() + { + Content = [new Content() { Text = "This is protected data that only authorized clients can access" }] + }); + } + + throw new McpException($"Unknown tool: '{request.Params?.Name}'"); + }, + + ListToolsHandler = async (_, _) => new() + { + Tools = + [ + new() + { + Name = "echo", + Description = "Echoes back the message you send" + }, + new() + { + Name = "protected-data", + Description = "Returns protected data that requires authorization" + } + ] + } + } + }; + }) + .WithAuthorization(authProvider) // Enable authorization with our provider + .WithHttpTransport(); // Configure HTTP transport + + var app = builder.Build(); + + // 5. Map MCP endpoints + // Note: Authorization is now handled automatically by WithAuthorization() + app.MapMcp(); + + // Configure the server URL + app.Urls.Add("http://localhost:7071"); + + Console.WriteLine("Starting MCP server with authorization at http://localhost:7071"); + Console.WriteLine("PRM Document URL: http://localhost:7071/.well-known/oauth-protected-resource"); + + Console.WriteLine(); + Console.WriteLine("To test the server:"); + Console.WriteLine("1. Use an MCP client that supports authorization"); + Console.WriteLine("2. When prompted for authorization, enter 'valid_token' to gain access"); + Console.WriteLine("3. Any other token value will be rejected with a 401 Unauthorized"); + Console.WriteLine(); + Console.WriteLine("Press Ctrl+C to stop the server"); + + await app.RunAsync(); + } +} \ No newline at end of file diff --git a/samples/AuthorizationServerExample/Properties/launchSettings.json b/samples/AuthorizationServerExample/Properties/launchSettings.json new file mode 100644 index 00000000..35989814 --- /dev/null +++ b/samples/AuthorizationServerExample/Properties/launchSettings.json @@ -0,0 +1,12 @@ +{ + "profiles": { + "AuthorizationServerExample": { + "commandName": "Project", + "launchBrowser": true, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + }, + "applicationUrl": "https://localhost:50481;http://localhost:50482" + } + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs new file mode 100644 index 00000000..7fd62a81 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationExtensions.cs @@ -0,0 +1,23 @@ +using Microsoft.AspNetCore.Builder; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Extension methods for using MCP authorization in ASP.NET Core applications. +/// +public static class AuthorizationExtensions +{ + /// + /// Adds MCP authorization middleware to the specified , which enables + /// OAuth 2.0 authorization for MCP servers. + /// + /// Note: This method is called automatically when using WithAuthorization(), so you typically + /// don't need to call it directly. It's available for advanced scenarios where more control is needed. + /// + /// The to add the middleware to. + /// A reference to this instance after the operation has completed. + public static IApplicationBuilder UseMcpAuthorization(this IApplicationBuilder builder) + { + return builder.UseMiddleware(); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs new file mode 100644 index 00000000..45fd9347 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationMiddleware.cs @@ -0,0 +1,64 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Middleware that handles authorization for MCP servers. +/// +internal class AuthorizationMiddleware +{ + private readonly RequestDelegate _next; + private readonly ILogger _logger; + + /// + /// Initializes a new instance of the class. + /// + /// The next middleware in the pipeline. + /// The logger factory. + public AuthorizationMiddleware(RequestDelegate next, ILogger logger) + { + _next = next ?? throw new ArgumentNullException(nameof(next)); + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + } /// + /// Processes a request. + /// + /// The HTTP context. + /// The MCP server options. + /// The authorization provider. + /// A representing the asynchronous operation. + public async Task InvokeAsync( + HttpContext context, + IOptions serverOptions, + IServerAuthorizationProvider? authProvider = null) + { + // Check if authorization is configured + if (authProvider == null) + { + _logger.LogDebug("Authorization is not configured, skipping authorization middleware"); + // Authorization is not configured, proceed to the next middleware + await _next(context); + return; + } + + // Handle the PRM document endpoint if not handled by the endpoint + if (context.Request.Path.StartsWithSegments("/.well-known/oauth-protected-resource") && + context.GetEndpoint() == null) + { + _logger.LogDebug("Serving Protected Resource Metadata document"); + context.Response.ContentType = "application/json"; + await JsonSerializer.SerializeAsync( + context.Response.Body, + authProvider.GetProtectedResourceMetadata(), + McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); + return; + } // Proceed to the next middleware - authorization for SSE and message endpoints + // is now handled by endpoint filters + await _next(context); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 8bff4596..ec74cdf0 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection.Extensions; using ModelContextProtocol.AspNetCore; using ModelContextProtocol.Server; @@ -18,7 +19,7 @@ public static class HttpMcpServerBuilderExtensions /// Configures options for the Streamable HTTP transport. This allows configuring per-session /// and running logic before and after a session. /// The builder provided in . - /// is . + /// is . public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder, Action? configureOptions = null) { ArgumentNullException.ThrowIfNull(builder); @@ -27,6 +28,9 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.TryAddSingleton(); builder.Services.AddHostedService(); + // Add our auto-registration for the authorization middleware + builder.Services.AddTransient(); + if (configureOptions is not null) { builder.Services.Configure(configureOptions); diff --git a/src/ModelContextProtocol.AspNetCore/McpAuthorizationStartupFilter.cs b/src/ModelContextProtocol.AspNetCore/McpAuthorizationStartupFilter.cs new file mode 100644 index 00000000..64b70280 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpAuthorizationStartupFilter.cs @@ -0,0 +1,37 @@ +// filepath: c:\Users\ddelimarsky\source\csharp-sdk\src\ModelContextProtocol.AspNetCore\McpAuthorizationStartupFilter.cs +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Auth; +using System; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// StartupFilter that automatically adds the MCP authorization middleware when authorization is configured. +/// +internal class McpAuthorizationStartupFilter : IStartupFilter +{ + /// + /// Configures the middleware pipeline to include MCP authorization middleware when needed. + /// + /// The next configurator in the chain. + /// A new pipeline configuration action. + public Action Configure(Action next) + { + return app => + { + // Check if authorization provider is registered + bool hasAuthProvider = app.ApplicationServices.GetService() != null; + + // If authorization is configured, add the middleware + if (hasAuthProvider) + { + app.UseMcpAuthorization(); + } + + // Continue with the rest of the pipeline configuration + next(app); + }; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationExtensions.cs new file mode 100644 index 00000000..5a8090fa --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationExtensions.cs @@ -0,0 +1,64 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Provides extension methods for adding MCP authorization to endpoints. +/// +public static class McpEndpointAuthorizationExtensions +{ + /// + /// Adds MCP authorization filter to an endpoint. + /// + /// The endpoint convention builder. + /// The authorization provider. + /// The service provider. + /// The builder for chaining. + public static IEndpointConventionBuilder AddMcpAuthorization( + this IEndpointConventionBuilder builder, + IServerAuthorizationProvider authProvider, + IServiceProvider serviceProvider) + { + if (authProvider == null) + { + return builder; // No authorization needed + } + + var logger = serviceProvider.GetRequiredService>(); + var filter = new McpEndpointAuthorizationFilter(logger, authProvider); + + return builder.AddEndpointFilter(filter); + } + + /// + /// Adds MCP authorization filter to multiple endpoints. + /// + /// The collection of endpoint convention builders. + /// The authorization provider. + /// The service provider. + /// The original collection for chaining. + public static IEnumerable AddMcpAuthorization( + this IEnumerable endpoints, + IServerAuthorizationProvider authProvider, + IServiceProvider serviceProvider) + { + if (authProvider == null) + { + return endpoints; // No authorization needed + } + + var logger = serviceProvider.GetRequiredService>(); + var filter = new McpEndpointAuthorizationFilter(logger, authProvider); + + foreach (var endpoint in endpoints) + { + endpoint.AddEndpointFilter(filter); + } + + return endpoints; + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs new file mode 100644 index 00000000..0d6e47eb --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointAuthorizationFilter.cs @@ -0,0 +1,56 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// An endpoint filter that handles authorization for MCP endpoints using the standard ASP.NET Core endpoint filter pattern. +/// +internal class McpEndpointAuthorizationFilter : IEndpointFilter +{ + private readonly ILogger _logger; + private readonly IServerAuthorizationProvider _authProvider; + + /// + /// Initializes a new instance of the class. + /// + /// The logger. + /// The authorization provider. + public McpEndpointAuthorizationFilter(ILogger logger, IServerAuthorizationProvider authProvider) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider)); + } /// + public async ValueTask InvokeAsync(EndpointFilterInvocationContext context, EndpointFilterDelegate next) + { + var httpContext = context.HttpContext; + + // Check if the Authorization header is present + if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeader) || string.IsNullOrEmpty(authHeader)) + { // No Authorization header present, return 401 Unauthorized + var prm = _authProvider.GetProtectedResourceMetadata(); + var prmUrl = ProtectedResourceMetadataHandler.GetProtectedResourceMetadataUrl(prm.Resource); + + _logger.LogDebug("Authorization required, returning 401 Unauthorized with WWW-Authenticate header"); + httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; + httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return Results.Empty; + } + + // Validate the token + string authHeaderValue = authHeader.ToString(); + bool isValid = await _authProvider.ValidateTokenAsync(authHeaderValue); + if (!isValid) + { // Invalid token, return 401 Unauthorized + var prm = _authProvider.GetProtectedResourceMetadata(); + var prmUrl = ProtectedResourceMetadataHandler.GetProtectedResourceMetadataUrl(prm.Resource); + + _logger.LogDebug("Invalid authorization token, returning 401 Unauthorized"); + httpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; + httpContext.Response.Headers.Append("WWW-Authenticate", $"Bearer resource_metadata=\"{prmUrl}\""); + return Results.Empty; + } // Token is valid, proceed to the next filter + return await next(context); + } +} diff --git a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs index 0eefa52f..a3cc9fb5 100644 --- a/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs @@ -2,7 +2,9 @@ using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Protocol.Messages; using System.Diagnostics.CodeAnalysis; @@ -20,12 +22,36 @@ public static class McpEndpointRouteBuilderExtensions /// /// The web application to attach MCP HTTP endpoints. /// The route pattern prefix to map to. - /// Returns a builder for configuring additional endpoint conventions like authorization policies. + /// Returns a builder for configuring additional endpoint conventions like authorization policies. public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern = "") { var streamableHttpHandler = endpoints.ServiceProvider.GetService() ?? throw new InvalidOperationException("You must call WithHttpTransport(). Unable to find required services. Call builder.Services.AddMcpServer().WithHttpTransport() in application startup code."); + // Map the protected resource metadata endpoint if authorization is configured + var authProvider = endpoints.ServiceProvider.GetService(); + if (authProvider != null) + { + // Create and register the ProtectedResourceMetadataHandler if it's not already registered + ProtectedResourceMetadataHandler? prmHandler = null; + try + { + prmHandler = endpoints.ServiceProvider.GetService(); + } + catch + { + // Ignore - we'll create it below + } + + if (prmHandler == null) + { + var logger = endpoints.ServiceProvider.GetRequiredService>(); + prmHandler = new ProtectedResourceMetadataHandler(logger, authProvider); + } + + endpoints.MapGet("/.well-known/oauth-protected-resource", prmHandler.HandleAsync); + } + var mcpGroup = endpoints.MapGroup(pattern); var streamableHttpGroup = mcpGroup.MapGroup("") .WithDisplayName(b => $"MCP Streamable HTTP | {b.DisplayName}") @@ -40,16 +66,24 @@ public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpo streamableHttpGroup.MapDelete("", streamableHttpHandler.HandleDeleteRequestAsync); // Map legacy HTTP with SSE endpoints. - var sseHandler = endpoints.ServiceProvider.GetRequiredService(); - var sseGroup = mcpGroup.MapGroup("") + var sseHandler = endpoints.ServiceProvider.GetRequiredService(); var sseGroup = mcpGroup.MapGroup("") .WithDisplayName(b => $"MCP HTTP with SSE | {b.DisplayName}"); - sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) + // Configure SSE endpoints + var sseEndpoint = sseGroup.MapGet("/sse", sseHandler.HandleSseRequestAsync) .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status200OK, contentTypes: ["text/event-stream"])); - sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) - .WithMetadata(new AcceptsMetadata(["application/json"])) - .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); + var messageEndpoint = sseGroup.MapPost("/message", sseHandler.HandleMessageRequestAsync) + .WithMetadata(new AcceptsMetadata(["application/json"])) + .WithMetadata(new ProducesResponseTypeMetadata(StatusCodes.Status202Accepted)); // Apply authorization filter directly to the endpoints if authorization is configured + if (authProvider != null) + { + // Apply authorization to both SSE endpoints using the extension method + new[] { sseEndpoint, messageEndpoint }.AddMcpAuthorization(authProvider, endpoints.ServiceProvider); + + // Apply authorization to the Streamable HTTP endpoints using the extension method + streamableHttpGroup.AddMcpAuthorization(authProvider, endpoints.ServiceProvider); + } return mcpGroup; } -} +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs b/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs new file mode 100644 index 00000000..5f0d2452 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/ProtectedResourceMetadataHandler.cs @@ -0,0 +1,58 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Utils.Json; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Handler for the Protected Resource Metadata document endpoint. +/// +internal class ProtectedResourceMetadataHandler +{ + private readonly ILogger _logger; + private readonly IServerAuthorizationProvider _authProvider; + + /// + /// Initializes a new instance of the class. + /// + /// The logger. + /// The authorization provider. + public ProtectedResourceMetadataHandler( + ILogger logger, + IServerAuthorizationProvider authProvider) + { + _logger = logger ?? throw new ArgumentNullException(nameof(logger)); + _authProvider = authProvider ?? throw new ArgumentNullException(nameof(authProvider)); + } + + /// + /// Handles the request for the Protected Resource Metadata document. + /// + /// The HTTP context. + /// A task that represents the asynchronous operation. + public async Task HandleAsync(HttpContext context) + { _logger.LogDebug("Serving Protected Resource Metadata document"); + context.Response.ContentType = "application/json"; + await JsonSerializer.SerializeAsync( + context.Response.Body, + _authProvider.GetProtectedResourceMetadata(), + McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ProtectedResourceMetadata))); + } /// + /// Builds the URL for the protected resource metadata endpoint. + /// + /// The resource URI from the protected resource metadata. + /// The full URL to the protected resource metadata endpoint. + /// Thrown when resourceUri is null. + public static string GetProtectedResourceMetadataUrl(Uri resourceUri) + { + if (resourceUri == null) + { + throw new ArgumentNullException(nameof(resourceUri), "Resource URI must be provided to build the protected resource metadata URL."); + } + + // Create a new URI with the well-known path appended + return new Uri(resourceUri, ".well-known/oauth-protected-resource").ToString(); + } +} diff --git a/src/ModelContextProtocol/AuthorizationException.cs b/src/ModelContextProtocol/AuthorizationException.cs new file mode 100644 index 00000000..893b5ef8 --- /dev/null +++ b/src/ModelContextProtocol/AuthorizationException.cs @@ -0,0 +1,69 @@ +namespace ModelContextProtocol; + +/// +/// Represents an exception that is thrown when an authorization or authentication error occurs in MCP. +/// +/// +/// This exception is thrown when the client fails to authenticate with an MCP server that requires +/// authentication, such as when the OAuth authorization flow fails or when the server rejects the provided credentials. +/// +public class AuthorizationException : McpException +{ + /// + /// Initializes a new instance of the class. + /// + public AuthorizationException() + : base("Authorization failed", McpErrorCode.InvalidRequest) + { + } + + /// + /// Initializes a new instance of the class with a specified error message. + /// + /// The message that describes the error. + public AuthorizationException(string message) + : base(message, McpErrorCode.InvalidRequest) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and a reference to the inner exception that is the cause of this exception. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + public AuthorizationException(string message, Exception? innerException) + : base(message, innerException, McpErrorCode.InvalidRequest) + { + } + + /// + /// Initializes a new instance of the class with a specified error message and error code. + /// + /// The message that describes the error. + /// The MCP error code. Should use one of the standard error codes. + public AuthorizationException(string message, McpErrorCode errorCode) + : base(message, errorCode) + { + } + + /// + /// Initializes a new instance of the class with a specified error message, inner exception, and error code. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception, or a null reference if no inner exception is specified. + /// The MCP error code. Should use one of the standard error codes. + public AuthorizationException(string message, Exception? innerException, McpErrorCode errorCode) + : base(message, innerException, errorCode) + { + } + + /// + /// Gets or sets the resource that requires authorization. + /// + public string? ResourceUri { get; set; } + + /// + /// Gets or sets the authorization server URI that should be used for authentication. + /// + public string? AuthorizationServerUri { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs new file mode 100644 index 00000000..cb9a7435 --- /dev/null +++ b/src/ModelContextProtocol/Configuration/McpServerAuthorizationExtensions.cs @@ -0,0 +1,47 @@ +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Auth; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol.Configuration; + +/// +/// Extension methods for configuring authorization in MCP servers. +/// +public static class McpServerAuthorizationExtensions +{ + /// + /// Adds authorization support to the MCP server and automatically configures the required middleware. + /// You don't need to call UseMcpAuthorization() separately - it will be handled automatically. + /// + /// The to configure. + /// The authorization provider that will validate tokens and provide metadata. + /// The so that additional calls can be chained. + /// or is . + /// + /// This method automatically configures all the necessary components for authorization: + /// 1. Registers the authorization provider in the DI container + /// 2. Configures authorization middleware to serve the protected resource metadata + /// 3. Adds authorization to MCP endpoints when they are mapped + /// + /// You no longer need to call app.UseMcpAuthorization() explicitly. + /// + public static IMcpServerBuilder WithAuthorization( + this IMcpServerBuilder builder, + IServerAuthorizationProvider authorizationProvider) + { + Throw.IfNull(builder); + Throw.IfNull(authorizationProvider); + + // Register the authorization provider in the DI container + builder.Services.AddSingleton(authorizationProvider); + + builder.Services.Configure(options => + { + options.Capabilities ??= new ServerCapabilities(); + }); + + return builder; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs new file mode 100644 index 00000000..c8e36c2f --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationContext.cs @@ -0,0 +1,93 @@ +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the context for authorization in an MCP client session. +/// +internal class AuthorizationContext +{ + /// + /// Gets or sets the resource metadata. + /// + public ProtectedResourceMetadata? ResourceMetadata { get; set; } + + /// + /// Gets or sets the authorization server metadata. + /// + public AuthorizationServerMetadata? AuthorizationServerMetadata { get; set; } + + /// + /// Gets or sets the client registration response. + /// + public ClientRegistration? ClientRegistration { get; set; } + + /// + /// Gets or sets the token response. + /// + public Token? TokenResponse { get; set; } + + /// + /// Gets or sets the code verifier for PKCE. + /// + public string? CodeVerifier { get; set; } + + /// + /// Gets or sets the redirect URI used in the authorization flow. + /// + public string? RedirectUri { get; set; } + + /// + /// Gets or sets the time when the access token was issued. + /// + public DateTimeOffset? TokenIssuedAt { get; set; } + + /// + /// Gets a value indicating whether the access token is valid. + /// + public bool HasValidToken => TokenResponse != null && + (TokenResponse.ExpiresIn == null || + TokenIssuedAt != null && + DateTimeOffset.UtcNow < TokenIssuedAt.Value.AddSeconds(TokenResponse.ExpiresIn.Value - 60)); // 1 minute buffer + + /// + /// Gets the access token for authentication. + /// + /// The access token if available, otherwise null. + public string? GetAccessToken() + { + if (!HasValidToken) + { + return null; + } + + // Since HasValidToken checks that TokenResponse isn't null, we should never have null here, + // but we'll add an explicit null check to satisfy the compiler + return TokenResponse?.AccessToken; + } + + /// + /// Gets a value indicating whether a refresh token is available for refreshing the access token. + /// + public bool CanRefreshToken => TokenResponse?.RefreshToken != null && + ClientRegistration != null && + AuthorizationServerMetadata != null; + + /// + /// Validates a URI resource against the resource URI from the metadata. + /// + /// The URI to validate. + /// True if the URIs match based on scheme and server components, otherwise false. + public bool ValidateResourceUrl(Uri resourceUri) + { + if (ResourceMetadata == null || ResourceMetadata.Resource == null || resourceUri == null) + { + return false; + } + + // Use Uri.Compare to properly compare the scheme and server components + // UriComponents.SchemeAndServer includes the scheme, host, port, and user info + // This handles edge cases like default ports, IPv6 addresses, etc. + return Uri.Compare(resourceUri, ResourceMetadata.Resource, + UriComponents.SchemeAndServer, UriFormat.UriEscaped, + StringComparison.OrdinalIgnoreCase) == 0; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs new file mode 100644 index 00000000..232de02e --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationOptions.cs @@ -0,0 +1,78 @@ +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Provides authorization options for MCP clients. +/// +public class AuthorizationOptions +{ + /// + /// Gets or sets a delegate that handles the OAuth 2.0 authorization code flow. + /// + /// + /// + /// This delegate is called when the server requires OAuth 2.0 authorization. It receives the client metadata + /// and should return the redirect URI and authorization code received from the authorization server. + /// + /// + /// If not provided, the client will not be able to authenticate with servers that require OAuth authentication. + /// + /// + public Func>? AuthorizeCallback { get; init; } + + /// + /// Gets or sets the client ID to use for OAuth authorization. + /// + /// + /// + /// If specified, this client ID will be used during the OAuth flow instead of performing dynamic client registration. + /// This is useful when connecting to servers that have pre-registered clients. + /// + /// + public string? ClientId { get; init; } + + /// + /// Gets or sets the client secret associated with the client ID. + /// + /// + /// This is only required if the client was registered as a confidential client with the authorization server. + /// Public clients don't require a client secret. + /// + public string? ClientSecret { get; init; } + + /// + /// Gets or sets the redirect URIs that can be used during the OAuth authorization flow. + /// + /// + /// + /// These URIs must match the redirect URIs registered with the authorization server for the client. + /// + /// + /// If not specified and is set, a default value of + /// "http://localhost:8888/callback" will be used. + /// + /// + public ICollection? RedirectUris { get; init; } + + /// + /// Gets or sets the scopes to request during OAuth authorization. + /// + /// + /// + /// If not specified, the scopes will be determined from the server's resource metadata. + /// + /// + public ICollection? Scopes { get; init; } + + /// + /// Gets or sets a custom authorization handler. + /// + /// + /// + /// If specified, this handler will be used to manage authorization with the server. + /// + /// + /// If not provided, a default handler will be created using the other options. + /// + /// + public IAuthorizationHandler? AuthorizationHandler { get; init; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs new file mode 100644 index 00000000..56ce385f --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationServerMetadata.cs @@ -0,0 +1,69 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents OAuth 2.0 authorization server metadata as defined in RFC 8414. +/// +public class AuthorizationServerMetadata +{ + /// + /// Gets or sets the authorization endpoint URL. + /// + [JsonPropertyName("authorization_endpoint")] + public required string AuthorizationEndpoint { get; set; } + + /// + /// Gets or sets the token endpoint URL. + /// + [JsonPropertyName("token_endpoint")] + public required string TokenEndpoint { get; set; } + + /// + /// Gets or sets the client registration endpoint URL. + /// + [JsonPropertyName("registration_endpoint")] + public string? RegistrationEndpoint { get; set; } + + /// + /// Gets or sets the token revocation endpoint URL. + /// + [JsonPropertyName("revocation_endpoint")] + public string? RevocationEndpoint { get; set; } + + /// + /// Gets or sets the response types supported by the authorization server. + /// + [JsonPropertyName("response_types_supported")] + public string[]? ResponseTypesSupported { get; set; } = ["code"]; + + /// + /// Gets or sets the grant types supported by the authorization server. + /// + [JsonPropertyName("grant_types_supported")] + public string[]? GrantTypesSupported { get; set; } = ["authorization_code", "refresh_token"]; + + /// + /// Gets or sets the token endpoint authentication methods supported by the authorization server. + /// + [JsonPropertyName("token_endpoint_auth_methods_supported")] + public string[]? TokenEndpointAuthMethodsSupported { get; set; } = ["client_secret_basic"]; + + /// + /// Gets or sets the code challenge methods supported by the authorization server. + /// + [JsonPropertyName("code_challenge_methods_supported")] + public string[]? CodeChallengeMethodsSupported { get; set; } = ["S256"]; + + /// + /// Gets or sets the issuer identifier. + /// + [JsonPropertyName("issuer")] + public string? Issuer { get; set; } + + /// + /// Gets or sets the scopes supported by the authorization server. + /// + [JsonPropertyName("scopes_supported")] + public string[]? ScopesSupported { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs new file mode 100644 index 00000000..16ff917c --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/AuthorizationService.cs @@ -0,0 +1,625 @@ +using ModelContextProtocol.Utils; +using ModelContextProtocol.Utils.Json; +using System.Net; +using System.Net.Http.Headers; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Provides OAuth 2.0 authorization services for MCP clients. +/// +public class AuthorizationService +{ + private static readonly HttpClient s_httpClient = new() + { + DefaultRequestHeaders = + { + Accept = { new MediaTypeWithQualityHeaderValue("application/json") } + } + }; + + /// + /// Gets resource metadata from a 401 Unauthorized response. + /// + /// The HTTP response that contains the WWW-Authenticate header. + /// A that represents the asynchronous operation. The task result contains the resource metadata if available. + public static async Task GetResourceMetadataFromResponseAsync(HttpResponseMessage response) + { + if (response.StatusCode != HttpStatusCode.Unauthorized) + { + return null; + } + + // Get the WWW-Authenticate header + if (!response.Headers.TryGetValues("WWW-Authenticate", out var authenticateValues)) + { + return null; + } + + // Parse the WWW-Authenticate header + string? resourceMetadataUrl = null; + foreach (var value in authenticateValues) + { + if (value.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + var parameters = ParseAuthHeaderParameters(value["Bearer ".Length..].Trim()); + + if (parameters.TryGetValue("resource_metadata", out var metadataUrl)) + { + resourceMetadataUrl = metadataUrl; + break; + } + } + } + + if (string.IsNullOrEmpty(resourceMetadataUrl)) + { + return null; + } + + // Fetch the resource metadata document + try + { + using var metadataResponse = await s_httpClient.GetAsync(resourceMetadataUrl); + metadataResponse.EnsureSuccessStatusCode(); + + var contentStream = await metadataResponse.Content.ReadAsStreamAsync(); + + // Read as string first, then deserialize using source-generated serializer + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + return JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ProtectedResourceMetadata); + } + catch (Exception) + { + // Failed to get resource metadata + return null; + } + } /// + /// Discovers authorization server metadata from a well-known endpoint. + /// + /// The base URL of the authorization server (as a string or Uri). + /// A that represents the asynchronous operation. The task result contains the authorization server metadata. + /// Thrown when both well-known endpoints return errors. + public static async Task DiscoverAuthorizationServerMetadataAsync(Uri authorizationServerUrl) + { + Throw.IfNull(authorizationServerUrl); + + // Create a base URI without any path, query, or fragment + var baseUri = new Uri($"{authorizationServerUrl.Scheme}://{authorizationServerUrl.Authority}"); + + // Try OpenID Connect discovery endpoint + var openIdConfigUrl = new Uri(baseUri, ".well-known/openid-configuration"); + try + { + using var openIdResponse = await s_httpClient.GetAsync(openIdConfigUrl); + if (openIdResponse.IsSuccessStatusCode) + { + var contentStream = await openIdResponse.Content.ReadAsStreamAsync(); + + // Use source-generated serialization instead of dynamic deserialization + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse authorization server metadata"); + } + + return result; + } + } + catch (Exception ex) when (ex is not InvalidOperationException) + { + // Failed to get OpenID configuration, try OAuth endpoint + } + + // Try OAuth 2.0 Authorization Server Metadata endpoint + var oauthConfigUrl = new Uri(baseUri, ".well-known/oauth-authorization-server"); try + { + using var oauthResponse = await s_httpClient.GetAsync(oauthConfigUrl); + if (oauthResponse.IsSuccessStatusCode) + { + var contentStream = await oauthResponse.Content.ReadAsStreamAsync(); + + // Use source-generated serialization instead of dynamic deserialization + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.AuthorizationServerMetadata); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse authorization server metadata"); + } + + return result; + } + } + catch (Exception ex) when (ex is not InvalidOperationException) + { + // Failed to get OAuth configuration + } throw new InvalidOperationException( + $"Failed to discover authorization server metadata for {authorizationServerUrl}. " + + "Neither OpenID Connect nor OAuth 2.0 well-known endpoints are available."); + } + + /// + /// Discovers authorization server metadata from a well-known endpoint. + /// + /// The base URL of the authorization server as a string. + /// A that represents the asynchronous operation. The task result contains the authorization server metadata. + /// Thrown when both well-known endpoints return errors. + /// Thrown when the URL is invalid. + public static Task DiscoverAuthorizationServerMetadataAsync(string authorizationServerUrl) + { + Throw.IfNullOrWhiteSpace(authorizationServerUrl); + + // Create a Uri from the string and call the Uri-based method + if (Uri.TryCreate(authorizationServerUrl, UriKind.Absolute, out Uri? uri)) + { + return DiscoverAuthorizationServerMetadataAsync(uri); + } + + throw new ArgumentException("Invalid URI format", nameof(authorizationServerUrl)); + } + + /// + /// Registers a client with the authorization server. + /// + /// The authorization server metadata. + /// The client metadata for registration. + /// A that represents the asynchronous operation. The task result contains the client registration response. + /// Thrown when the authorization server does not support dynamic client registration. + public static async Task RegisterClientAsync( + AuthorizationServerMetadata metadata, + ClientMetadata clientMetadata) + { + Throw.IfNull(metadata); + Throw.IfNull(clientMetadata); + + if (metadata.RegistrationEndpoint == null) + { + throw new InvalidOperationException("The authorization server does not support dynamic client registration."); + } + + var content = new StringContent( + JsonSerializer.Serialize(clientMetadata, McpJsonUtilities.JsonContext.Default.ClientMetadata), + Encoding.UTF8, + "application/json"); + + using var response = await s_httpClient.PostAsync(metadata.RegistrationEndpoint, content); + response.EnsureSuccessStatusCode(); + + // Use source-generated serialization instead of dynamic deserialization + var contentStream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.ClientRegistration); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse client registration response"); + } + + return result; + } + + /// + /// Generates a code verifier and code challenge for PKCE. + /// + /// A tuple containing the code verifier and code challenge. + public static (string CodeVerifier, string CodeChallenge) GeneratePkceValues() + { + // Generate a random code verifier + var bytes = new byte[32]; + using var rng = RandomNumberGenerator.Create(); + rng.GetBytes(bytes); + var codeVerifier = Convert.ToBase64String(bytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + + // Generate the code challenge (S256) + using var sha256 = SHA256.Create(); + var challengeBytes = sha256.ComputeHash(Encoding.UTF8.GetBytes(codeVerifier)); + var codeChallenge = Convert.ToBase64String(challengeBytes) + .TrimEnd('=') + .Replace('+', '-') + .Replace('/', '_'); + + return (codeVerifier, codeChallenge); + } + + /// + /// Creates an authorization URL for the OAuth authorization code flow with PKCE. + /// + /// The authorization server metadata. + /// The client identifier. + /// The redirect URI. + /// The code challenge for PKCE. + /// The requested scopes. + /// A value used to maintain state between the request and callback. + /// The authorization URL. + public static string CreateAuthorizationUrl( + AuthorizationServerMetadata metadata, + string clientId, + string redirectUri, + string codeChallenge, + string[]? scopes = null, + string? state = null) + { + Throw.IfNull(metadata); + Throw.IfNullOrWhiteSpace(clientId); + Throw.IfNullOrWhiteSpace(redirectUri); + Throw.IfNullOrWhiteSpace(codeChallenge); + + var queryBuilder = new StringBuilder(metadata.AuthorizationEndpoint); + queryBuilder.Append(metadata.AuthorizationEndpoint.Contains('?') ? '&' : '?'); + queryBuilder.Append("response_type=code"); + queryBuilder.Append($"&client_id={Uri.EscapeDataString(clientId)}"); + queryBuilder.Append($"&redirect_uri={Uri.EscapeDataString(redirectUri)}"); + queryBuilder.Append($"&code_challenge={Uri.EscapeDataString(codeChallenge)}"); + queryBuilder.Append("&code_challenge_method=S256"); + + if (scopes != null && scopes.Length > 0) + { + queryBuilder.Append($"&scope={Uri.EscapeDataString(string.Join(" ", scopes))}"); + } + + if (!string.IsNullOrEmpty(state)) + { + queryBuilder.Append($"&state={Uri.EscapeDataString(state)}"); + } + + return queryBuilder.ToString(); + } + + /// + /// Exchanges an authorization code for tokens. + /// + /// The authorization server metadata. + /// The client identifier. + /// The client secret. + /// The redirect URI. + /// The authorization code received from the authorization server. + /// The code verifier for PKCE. + /// A that represents the asynchronous operation. The task result contains the token response. + public static async Task ExchangeCodeForTokensAsync( + AuthorizationServerMetadata metadata, + string clientId, + string? clientSecret, + string redirectUri, + string code, + string codeVerifier) + { + Throw.IfNull(metadata); + Throw.IfNullOrWhiteSpace(clientId); + Throw.IfNullOrWhiteSpace(redirectUri); + Throw.IfNullOrWhiteSpace(code); + Throw.IfNullOrWhiteSpace(codeVerifier); + + var tokenRequestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "authorization_code", + ["code"] = code, + ["redirect_uri"] = redirectUri, + ["client_id"] = clientId, + ["code_verifier"] = codeVerifier + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, metadata.TokenEndpoint) + { + Content = tokenRequestContent + }; + + // Add client authentication if client secret is provided + if (!string.IsNullOrEmpty(clientSecret)) + { + var credentials = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{clientId}:{clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", credentials); + } + + using var response = await s_httpClient.SendAsync(request); + response.EnsureSuccessStatusCode(); + + // Use source-generated serialization instead of dynamic deserialization + var contentStream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.Token); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse token response"); + } + + return result; + } + + /// + /// Refreshes an access token using a refresh token. + /// + /// The authorization server metadata. + /// The client identifier. + /// The client secret. + /// The refresh token. + /// A that represents the asynchronous operation. The task result contains the token response. + public static async Task RefreshTokenAsync( + AuthorizationServerMetadata metadata, + string clientId, + string? clientSecret, + string refreshToken) + { + Throw.IfNull(metadata); + Throw.IfNullOrWhiteSpace(clientId); + Throw.IfNullOrWhiteSpace(refreshToken); + + var tokenRequestContent = new FormUrlEncodedContent(new Dictionary + { + ["grant_type"] = "refresh_token", + ["refresh_token"] = refreshToken, + ["client_id"] = clientId + }); + + using var request = new HttpRequestMessage(HttpMethod.Post, metadata.TokenEndpoint) + { + Content = tokenRequestContent + }; + + // Add client authentication if client secret is provided + if (!string.IsNullOrEmpty(clientSecret)) + { + var credentials = Convert.ToBase64String(Encoding.UTF8.GetBytes($"{clientId}:{clientSecret}")); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", credentials); + } + + using var response = await s_httpClient.SendAsync(request); + response.EnsureSuccessStatusCode(); + + // Use source-generated serialization instead of dynamic deserialization + var contentStream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(contentStream); + var json = await reader.ReadToEndAsync(); + var result = JsonSerializer.Deserialize(json, McpJsonUtilities.JsonContext.Default.Token); + + if (result == null) + { + throw new InvalidOperationException("Failed to parse token response"); + } + + return result; + } + + private static Dictionary ParseAuthHeaderParameters(string parameters) + { + var result = new Dictionary(StringComparer.OrdinalIgnoreCase); + + var start = 0; + while (start < parameters.Length) + { + // Find the next key=value pair + var equalPos = parameters.IndexOf('=', start); + if (equalPos == -1) + break; + + var key = parameters[start..equalPos].Trim(); + start = equalPos + 1; + + // Check if the value is quoted + if (start < parameters.Length && parameters[start] == '"') + { + start++; // Skip the opening quote + + // Find the closing quote + var endQuote = start; + while (endQuote < parameters.Length) + { + endQuote = parameters.IndexOf('"', endQuote); + if (endQuote == -1) + break; + + // Check if this is an escaped quote + if (endQuote > 0 && parameters[endQuote - 1] == '\\') + { + endQuote++; // Skip the escaped quote + continue; + } + + break; // Found a non-escaped quote + } + + if (endQuote == -1) + endQuote = parameters.Length; // No closing quote found, use the rest of the string + + var value = parameters[start..endQuote]; + result[key] = value.Replace("\\\"", "\""); // Unescape quotes + + // Move past the closing quote and any following comma + start = endQuote + 1; + var commaPos = parameters.IndexOf(',', start); + if (commaPos != -1) + start = commaPos + 1; + else + break; + } + else + { + // Unquoted value, ends at the next comma or end of string + var commaPos = parameters.IndexOf(',', start); + var value = commaPos != -1 + ? parameters[start..commaPos].Trim() + : parameters[start..].Trim(); + + result[key] = value; + + if (commaPos == -1) + break; + + start = commaPos + 1; + } } + + return result; + } + + /// + /// Creates an HTTP listener callback for handling OAuth 2.0 authorization code flow. + /// + /// A function that opens a browser with the given URL. + /// The hostname to listen on. Defaults to "localhost". + /// The port to listen on. Defaults to 8888. + /// The redirect path for the HTTP listener. Defaults to "/callback". + /// + /// A function that takes and returns a task that resolves to a tuple containing + /// the redirect URI and the authorization code. + /// + public static Func> CreateHttpListenerAuthorizeCallback( + Func openBrowser, + string hostname = "localhost", + int listenPort = 8888, + string redirectPath = "/callback") + { + return async (ClientMetadata clientMetadata) => + { + string redirectUri = $"http://{hostname}:{listenPort}{redirectPath}"; + + foreach (var uri in clientMetadata.RedirectUris) + { + if (uri.StartsWith($"http://{hostname}", StringComparison.OrdinalIgnoreCase) && + Uri.TryCreate(uri, UriKind.Absolute, out var parsedUri)) + { + redirectUri = uri; + listenPort = parsedUri.IsDefaultPort ? 80 : parsedUri.Port; + redirectPath = parsedUri.AbsolutePath; + break; + } + } + + var authCodeTcs = new TaskCompletionSource(); + // Ensure the path has a trailing slash for the HttpListener prefix + string listenerPrefix = $"http://{hostname}:{listenPort}{redirectPath}"; + if (!listenerPrefix.EndsWith("/")) + { + listenerPrefix += "/"; + } + + using var listener = new HttpListener(); + listener.Prefixes.Add(listenerPrefix); + + // Start the listener BEFORE opening the browser + try + { + listener.Start(); + } + catch (HttpListenerException ex) + { + throw new McpException($"Failed to start HTTP listener on {listenerPrefix}: {ex.Message}", McpErrorCode.InvalidRequest); + } + + // Create a cancellation token source with a timeout + using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(5)); + + _ = Task.Run(async () => + { + try + { + // GetContextAsync doesn't accept a cancellation token, so we need to handle cancellation manually + var contextTask = listener.GetContextAsync(); + var completedTask = await Task.WhenAny(contextTask, Task.Delay(Timeout.Infinite, cts.Token)); + + if (completedTask == contextTask) + { + var context = await contextTask; + var request = context.Request; + var response = context.Response; + + string? code = request.QueryString["code"]; + string? error = request.QueryString["error"]; + string html; + string? resultCode = null; + + if (!string.IsNullOrEmpty(error)) + { + html = $"

Authorization Failed

Error: {WebUtility.HtmlEncode(error)}

"; + } + else if (string.IsNullOrEmpty(code)) + { + html = "

Authorization Failed

No authorization code received.

"; + } + else + { + html = "

Authorization Successful

You may now close this window.

"; + resultCode = code; + } + + try + { + // Send response to browser + byte[] buffer = Encoding.UTF8.GetBytes(html); + response.ContentType = "text/html"; + response.ContentLength64 = buffer.Length; + response.OutputStream.Write(buffer, 0, buffer.Length); + + // IMPORTANT: Explicitly close the response to ensure it's fully sent + response.Close(); + + // Now that we've finished processing the browser response, + // we can safely signal completion or failure with the auth code + if (resultCode != null) + { + authCodeTcs.TrySetResult(resultCode); + } + else if (!string.IsNullOrEmpty(error)) + { + authCodeTcs.TrySetException(new McpException($"Authorization failed: {error}", McpErrorCode.InvalidRequest)); + } + else + { + authCodeTcs.TrySetException(new McpException("No authorization code received", McpErrorCode.InvalidRequest)); + } + } + catch (Exception ex) + { + authCodeTcs.TrySetException(new McpException($"Error processing browser response: {ex.Message}", McpErrorCode.InvalidRequest)); + } + } + } + catch (Exception ex) + { + authCodeTcs.TrySetException(ex); + } + }); + + // Now open the browser AFTER the listener is started + if (!string.IsNullOrEmpty(clientMetadata.ClientUri)) + { + await openBrowser(clientMetadata.ClientUri!); + } + else + { + // Stop the listener before throwing + listener.Stop(); + throw new McpException("Client URI is missing in metadata.", McpErrorCode.InvalidRequest); + } + + try + { + // Use a timeout to avoid hanging indefinitely + string authCode = await authCodeTcs.Task.WaitAsync(cts.Token); + return (redirectUri, authCode); + } + catch (OperationCanceledException) + { + throw new McpException("Authorization timed out after 5 minutes.", McpErrorCode.InvalidRequest); + } + finally + { + // Ensure the listener is stopped when we're done + listener.Stop(); + } + }; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ClientMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ClientMetadata.cs new file mode 100644 index 00000000..d650c5a2 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/ClientMetadata.cs @@ -0,0 +1,99 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the OAuth 2.0 client registration metadata as defined in RFC 7591. +/// +public class ClientMetadata +{ + /// + /// Gets or sets the array of redirection URI strings for use in redirect-based flows. + /// + [JsonPropertyName("redirect_uris")] + public required string[] RedirectUris { get; set; } + + /// + /// Gets or sets the requested authentication method for the token endpoint. + /// + [JsonPropertyName("token_endpoint_auth_method")] + public string? TokenEndpointAuthMethod { get; set; } = "client_secret_basic"; + + /// + /// Gets or sets the array of OAuth 2.0 grant type strings that the client can use at the token endpoint. + /// + [JsonPropertyName("grant_types")] + public string[]? GrantTypes { get; set; } = ["authorization_code", "refresh_token"]; + + /// + /// Gets or sets the array of the OAuth 2.0 response type strings that the client can use at the authorization endpoint. + /// + [JsonPropertyName("response_types")] + public string[]? ResponseTypes { get; set; } = ["code"]; + + /// + /// Gets or sets the human-readable string name of the client. + /// + [JsonPropertyName("client_name")] + public string? ClientName { get; set; } + + /// + /// Gets or sets the URL string of a web page providing information about the client. + /// + [JsonPropertyName("client_uri")] + public string? ClientUri { get; set; } + + /// + /// Gets or sets the URL string that references a logo for the client. + /// + [JsonPropertyName("logo_uri")] + public string? LogoUri { get; set; } + + /// + /// Gets or sets the string containing a space-separated list of scope values that the client can use. + /// + [JsonPropertyName("scope")] + public string? Scope { get; set; } + + /// + /// Gets or sets the array of strings representing ways to contact people responsible for this client. + /// + [JsonPropertyName("contacts")] + public string[]? Contacts { get; set; } + + /// + /// Gets or sets the URL string that points to a human-readable terms of service document. + /// + [JsonPropertyName("tos_uri")] + public string? TosUri { get; set; } + + /// + /// Gets or sets the URL string that points to a human-readable privacy policy document. + /// + [JsonPropertyName("policy_uri")] + public string? PolicyUri { get; set; } + + /// + /// Gets or sets the URL string referencing the client's JSON Web Key Set document. + /// + [JsonPropertyName("jwks_uri")] + public string? JwksUri { get; set; } + + /// + /// Gets or sets the client's JSON Web Key Set document value. + /// + [JsonPropertyName("jwks")] + public object? Jwks { get; set; } + + /// + /// Gets or sets a unique identifier string assigned by the client developer or software publisher. + /// + [JsonPropertyName("software_id")] + public string? SoftwareId { get; set; } + + /// + /// Gets or sets the version identifier string for the client software. + /// + [JsonPropertyName("software_version")] + public string? SoftwareVersion { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ClientRegistration.cs b/src/ModelContextProtocol/Protocol/Auth/ClientRegistration.cs new file mode 100644 index 00000000..d4e66b98 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/ClientRegistration.cs @@ -0,0 +1,33 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the OAuth 2.0 client registration response as defined in RFC 7591. +/// +public class ClientRegistration +{ + /// + /// Gets or sets the OAuth 2.0 client identifier string. + /// + [JsonPropertyName("client_id")] + public required string ClientId { get; set; } + + /// + /// Gets or sets the OAuth 2.0 client secret string. + /// + [JsonPropertyName("client_secret")] + public string? ClientSecret { get; set; } + + /// + /// Gets or sets the time at which the client identifier was issued. + /// + [JsonPropertyName("client_id_issued_at")] + public long? ClientIdIssuedAt { get; set; } + + /// + /// Gets or sets the time at which the client secret will expire or 0 if it will not expire. + /// + [JsonPropertyName("client_secret_expires_at")] + public long? ClientSecretExpiresAt { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs new file mode 100644 index 00000000..c10286f3 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/DefaultAuthorizationHandler.cs @@ -0,0 +1,295 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Utils; +using System.Net; +using System.Net.Http.Headers; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Provides authorization handling for MCP clients. +/// +internal class DefaultAuthorizationHandler : IAuthorizationHandler +{ + private readonly ILogger _logger; + private readonly SynchronizedValue _authContext = new(new AuthorizationContext()); + private readonly Func>? _authorizeCallback; + private readonly string? _clientId; + private readonly string? _clientSecret; + private readonly ICollection? _redirectUris; + private readonly ICollection? _scopes; + + /// + /// Initializes a new instance of the class. + /// + /// The logger factory. + /// The authorization options. + public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, AuthorizationOptions? options = null) + { + _logger = loggerFactory != null + ? loggerFactory.CreateLogger() + : NullLogger.Instance; + + if (options != null) + { + _authorizeCallback = options.AuthorizeCallback; + _clientId = options.ClientId; + _clientSecret = options.ClientSecret; + _redirectUris = options.RedirectUris; + _scopes = options.Scopes; + } + } + + /// + /// Initializes a new instance of the class. + /// + /// The logger factory. + /// A callback function that handles the authorization code flow. + public DefaultAuthorizationHandler(ILoggerFactory? loggerFactory = null, Func>? authorizeCallback = null) + : this(loggerFactory, new AuthorizationOptions { AuthorizeCallback = authorizeCallback }) + { + } + + /// + public async Task AuthenticateRequestAsync(HttpRequestMessage request) + { + // Try to get a valid token, refreshing if necessary + var token = await GetValidTokenAsync(); + + if (!string.IsNullOrEmpty(token)) + { + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token); + } + } + + /// + public async Task HandleUnauthorizedResponseAsync(HttpResponseMessage response, Uri serverUri) + { + if (response.StatusCode != HttpStatusCode.Unauthorized) + { + return false; + } + + _logger.LogDebug("Received 401 Unauthorized response from {ServerUri}", serverUri); + + using var authContext = await _authContext.LockAsync(); + + // If we already have a valid token, it might be that the token was just revoked + // or has other issues - we need to clear our state and retry the authorization flow + if (authContext.Value.HasValidToken) + { + _logger.LogWarning("Server rejected our authentication token. Clearing authentication state and reauthorizing."); + authContext.Value = new AuthorizationContext(); + } + + // Try to get resource metadata from the response + var resourceMetadata = await AuthorizationService.GetResourceMetadataFromResponseAsync(response); + if (resourceMetadata == null) + { + _logger.LogWarning("Failed to extract resource metadata from 401 response"); + + // Create a more specific exception + var exception = new AuthorizationException("Authorization required but no resource metadata available") + { + ResourceUri = serverUri.ToString() + }; + throw exception; + } // Store the resource metadata in the context before validating the resource URL + authContext.Value.ResourceMetadata = resourceMetadata; + + // Validate that the resource matches the server FQDN + if (!authContext.Value.ValidateResourceUrl(serverUri)) + { + _logger.LogWarning("Resource URL mismatch: expected {Expected}, got {Actual}", + serverUri, resourceMetadata.Resource); + + var exception = new AuthorizationException($"Resource URL mismatch: expected {serverUri}, got {resourceMetadata.Resource}"); + exception.ResourceUri = resourceMetadata.Resource.ToString(); + throw exception; + } + + // Get the first authorization server from the metadata + if (resourceMetadata.AuthorizationServers == null || resourceMetadata.AuthorizationServers.Length == 0) + { _logger.LogWarning("No authorization servers found in resource metadata"); + + var exception = new AuthorizationException("No authorization servers available"); + exception.ResourceUri = resourceMetadata.Resource.ToString(); + throw exception; + } + + var authServerUrl = resourceMetadata.AuthorizationServers[0]; + _logger.LogDebug("Using authorization server: {AuthServerUrl}", authServerUrl); try + { + // Discover authorization server metadata + var authServerMetadata = await AuthorizationService.DiscoverAuthorizationServerMetadataAsync(authServerUrl); + authContext.Value.AuthorizationServerMetadata = authServerMetadata; + _logger.LogDebug("Successfully retrieved authorization server metadata"); + + // Create client metadata + string[] redirectUris = _redirectUris?.ToArray() ?? new[] { "http://localhost:8888/callback" }; + var clientMetadata = new ClientMetadata + { + RedirectUris = redirectUris, + ClientName = "MCP C# SDK Client", + Scope = string.Join(" ", _scopes ?? resourceMetadata.ScopesSupported ?? Array.Empty()) + }; + + // Register client if needed, or use pre-configured client ID + if (!string.IsNullOrEmpty(_clientId)) + { + _logger.LogDebug("Using pre-configured client ID: {ClientId}", _clientId); + + // Create a client registration response to store in the context + var clientRegistration = new ClientRegistration + { + ClientId = _clientId!, // Using null-forgiving operator since we've already checked it's not null + ClientSecret = _clientSecret, + }; + + authContext.Value.ClientRegistration = clientRegistration; + } + else if (authServerMetadata.RegistrationEndpoint != null) + { + // Register client dynamically + _logger.LogDebug("Registering client with authorization server"); + var clientRegistration = await AuthorizationService.RegisterClientAsync(authServerMetadata, clientMetadata); + authContext.Value.ClientRegistration = clientRegistration; + _logger.LogDebug("Client registered successfully with ID: {ClientId}", clientRegistration.ClientId); + } + else + { + _logger.LogWarning("Authorization server does not support dynamic client registration and no client ID was provided"); + var exception = new AuthorizationException( + "Authorization server does not support dynamic client registration and no client ID was provided. " + + "Use McpAuthorizationOptions.ClientId to provide a pre-registered client ID."); + exception.ResourceUri = resourceMetadata.Resource.ToString(); + exception.AuthorizationServerUri = authServerUrl.ToString(); + throw exception; + } + + // If we have no way to handle user authorization, we can't proceed + if (_authorizeCallback == null) + { + _logger.LogWarning("No authorization callback provided, can't proceed with OAuth flow"); + var exception = new AuthorizationException( + "Authentication is required but no authorization callback was provided. " + + "Use McpAuthorizationOptions.AuthorizeCallback to provide a callback function."); + exception.ResourceUri = resourceMetadata.Resource.ToString(); + exception.AuthorizationServerUri = authServerUrl.ToString(); + throw exception; + } + + // Generate PKCE values + var (codeVerifier, codeChallenge) = AuthorizationService.GeneratePkceValues(); + authContext.Value.CodeVerifier = codeVerifier; + + // Initiate authorization code flow + _logger.LogDebug("Initiating authorization code flow"); + + // Get the authorization URL that the user needs to visit + var authUrl = AuthorizationService.CreateAuthorizationUrl( + authServerMetadata, + authContext.Value.ClientRegistration.ClientId, + clientMetadata.RedirectUris[0], + codeChallenge, + _scopes?.ToArray() ?? resourceMetadata.ScopesSupported); + + _logger.LogDebug("Authorization URL: {AuthUrl}", authUrl); + + // Set the authorization URL in the client metadata + clientMetadata.ClientUri = authUrl; + + // Let the callback handle the user authorization + var (redirectUri, code) = await _authorizeCallback(clientMetadata); + authContext.Value.RedirectUri = redirectUri; + + // Exchange the code for tokens + _logger.LogDebug("Exchanging authorization code for tokens"); + var tokenResponse = await AuthorizationService.ExchangeCodeForTokensAsync( + authServerMetadata, + authContext.Value.ClientRegistration.ClientId, + authContext.Value.ClientRegistration.ClientSecret, + redirectUri, + code, + codeVerifier); + + authContext.Value.TokenResponse = tokenResponse; + authContext.Value.TokenIssuedAt = DateTimeOffset.UtcNow; + + _logger.LogDebug("Successfully obtained access token"); + return true; + } + catch (Exception ex) when (ex is not AuthorizationException) + { + _logger.LogError(ex, "Failed to complete authorization flow"); + var authException = new AuthorizationException( + $"Failed to complete authorization flow: {ex.Message}", ex, McpErrorCode.InvalidRequest); + + authException.ResourceUri = resourceMetadata.Resource.ToString(); + authException.AuthorizationServerUri = authServerUrl.ToString(); + + throw authException; + } + } + + private async Task GetValidTokenAsync() + { + using var authContext = await _authContext.LockAsync(); + + // If we have a valid token, use it + if (authContext.Value.HasValidToken) + { + _logger.LogDebug("Using existing valid access token"); + return authContext.Value.GetAccessToken(); + } + + // If we can refresh the token, do so + if (authContext.Value.CanRefreshToken) + { + try + { + _logger.LogDebug("Refreshing access token"); + + // Null checks to ensure parameters are valid + if (authContext.Value.AuthorizationServerMetadata == null) + { + _logger.LogError("Cannot refresh token: AuthorizationServerMetadata is null"); + return null; + } + + if (authContext.Value.ClientRegistration == null) + { + _logger.LogError("Cannot refresh token: ClientRegistration is null"); + return null; + } + + if (authContext.Value.TokenResponse?.RefreshToken == null) + { + _logger.LogError("Cannot refresh token: RefreshToken is null"); + return null; + } + + var tokenResponse = await AuthorizationService.RefreshTokenAsync( + authContext.Value.AuthorizationServerMetadata, + authContext.Value.ClientRegistration.ClientId, + authContext.Value.ClientRegistration.ClientSecret, + authContext.Value.TokenResponse.RefreshToken); + + authContext.Value.TokenResponse = tokenResponse; + authContext.Value.TokenIssuedAt = DateTimeOffset.UtcNow; + + _logger.LogDebug("Successfully refreshed access token"); + return tokenResponse.AccessToken; + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to refresh access token"); + // Clear the token so we'll try to reauthenticate on the next request + authContext.Value.TokenResponse = null; + authContext.Value.TokenIssuedAt = null; + } + } + + return null; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs b/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs new file mode 100644 index 00000000..ffa41acb --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/IAuthorizationHandler.cs @@ -0,0 +1,22 @@ +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Defines methods for handling authorization in an MCP client. +/// +public interface IAuthorizationHandler +{ + /// + /// Handles authentication for HTTP requests. + /// + /// The HTTP request to authenticate. + /// A representing the asynchronous operation. + Task AuthenticateRequestAsync(HttpRequestMessage request); + + /// + /// Handles a 401 Unauthorized response. + /// + /// The HTTP response that contains the 401 status code. + /// The URI of the server that returned the 401 status code. + /// A that represents the asynchronous operation. The task result contains true if the authentication was successful and the request should be retried, otherwise false. + Task HandleUnauthorizedResponseAsync(HttpResponseMessage response, Uri serverUri); +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs b/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs new file mode 100644 index 00000000..4aee3f22 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/IServerAuthorizationProvider.cs @@ -0,0 +1,24 @@ +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Defines the interface for MCP server authorization providers. +/// +/// +/// This interface is implemented by authorization providers that enable MCP servers to validate tokens +/// and control access to protected resources. +/// +public interface IServerAuthorizationProvider +{ + /// + /// Gets the Protected Resource Metadata (PRM) for the server. + /// + /// The protected resource metadata. + ProtectedResourceMetadata GetProtectedResourceMetadata(); + + /// + /// Validates the provided authorization token. + /// + /// The authorization header value. + /// A representing the asynchronous validation operation. The task result contains if the token is valid; otherwise, . + Task ValidateTokenAsync(string authorizationHeader); +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs b/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs new file mode 100644 index 00000000..e80d58ad --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/ProtectedResourceMetadata.cs @@ -0,0 +1,45 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the Protected Resource Metadata (PRM) document for an OAuth 2.0 protected resource. +/// +/// +/// The PRM document describes the properties and requirements of a protected resource, including +/// the authorization servers that can be used to obtain access tokens and the scopes that are supported. +/// This document is served at the standard path "/.well-known/oauth-protected-resource" by MCP servers +/// that have authorization enabled. +/// +public class ProtectedResourceMetadata +{ + /// + /// Gets or sets the resource identifier URI. + /// + [JsonPropertyName("resource")] + public required Uri Resource { get; set; } + + /// + /// Gets or sets the authorization servers that can be used for authentication. + /// + [JsonPropertyName("authorization_servers")] + public required Uri[] AuthorizationServers { get; set; } + + /// + /// Gets or sets the bearer token methods supported by the resource. + /// + [JsonPropertyName("bearer_methods_supported")] + public string[]? BearerMethodsSupported { get; set; } = ["header"]; + + /// + /// Gets or sets the scopes supported by the resource. + /// + [JsonPropertyName("scopes_supported")] + public string[]? ScopesSupported { get; set; } + + /// + /// Gets or sets the URL to the resource documentation. + /// + [JsonPropertyName("resource_documentation")] + public Uri? ResourceDocumentation { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Auth/Token.cs b/src/ModelContextProtocol/Protocol/Auth/Token.cs new file mode 100644 index 00000000..f9559068 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Auth/Token.cs @@ -0,0 +1,39 @@ +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol.Auth; + +/// +/// Represents the OAuth 2.0 token response as defined in RFC 6749. +/// +public class Token +{ + /// + /// Gets or sets the access token issued by the authorization server. + /// + [JsonPropertyName("access_token")] + public required string AccessToken { get; set; } + + /// + /// Gets or sets the type of the token issued. + /// + [JsonPropertyName("token_type")] + public required string TokenType { get; set; } + + /// + /// Gets or sets the lifetime in seconds of the access token. + /// + [JsonPropertyName("expires_in")] + public long? ExpiresIn { get; set; } + + /// + /// Gets or sets the refresh token, which can be used to obtain new access tokens. + /// + [JsonPropertyName("refresh_token")] + public string? RefreshToken { get; set; } + + /// + /// Gets or sets the scope of the access token. + /// + [JsonPropertyName("scope")] + public string? Scope { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index 1b286557..fc8a35e5 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol.Auth; using ModelContextProtocol.Utils; namespace ModelContextProtocol.Protocol.Transport; @@ -7,10 +8,15 @@ namespace ModelContextProtocol.Protocol.Transport; /// Provides an over HTTP using the Server-Sent Events (SSE) protocol. /// /// +/// /// This transport connects to an MCP server over HTTP using SSE, /// allowing for real-time server-to-client communication with a standard HTTP request. /// Unlike the , this transport connects to an existing server /// rather than launching a new process. +/// +/// +/// The SSE transport can handle OAuth 2.0 authorization flows when connecting to servers that require authentication. +/// /// public sealed class SseClientTransport : IClientTransport, IAsyncDisposable { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs index b83204ae..c696c664 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransportOptions.cs @@ -1,5 +1,7 @@ namespace ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Auth; + /// /// Provides options for configuring instances. /// @@ -62,4 +64,35 @@ public required Uri Endpoint /// Use this property to specify custom HTTP headers that should be sent with each request to the server. /// public Dictionary? AdditionalHeaders { get; init; } + + /// + /// Gets or sets the authorization options to use when connecting to the SSE server. + /// + /// + /// + /// These options configure the behavior of client-side authorization with the SSE server. + /// + /// + /// You can use this to specify a callback for handling the authorization code flow, + /// provide pre-registered client credentials, or configure other aspects of the OAuth flow. + /// + /// + /// Example: + /// + /// var transportOptions = new SseClientTransportOptions + /// { /// Endpoint = new Uri("http://localhost:7071/sse"), + /// AuthorizationOptions = new McpAuthorizationOptions + /// { + /// ClientId = "my-client-id", + /// ClientSecret = "my-client-secret", + /// RedirectUris = new[] { "http://localhost:8888/callback" }, + /// AuthorizeCallback = AuthorizationService.CreateHttpListenerAuthorizeCallback( + /// openBrowser: url => Process.Start(new ProcessStartInfo(url) { UseShellExecute = true }) + /// ) + /// } + /// }; + /// + /// + /// + public AuthorizationOptions? AuthorizationOptions { get; init; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs b/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs new file mode 100644 index 00000000..47557b00 --- /dev/null +++ b/src/ModelContextProtocol/Protocol/Types/AuthorizationCapability.cs @@ -0,0 +1,19 @@ +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.Protocol.Types; + +/// +/// Defines the capabilities of a server for supporting OAuth 2.0 authorization. +/// +/// +/// This capability is advertised by servers that support OAuth 2.0 authorization flows +/// and require clients to authenticate using bearer tokens. +/// +public class AuthorizationCapability +{ + /// + /// Gets or sets the authorization provider that handles token validation and provides + /// metadata about the protected resource. + /// + public IServerAuthorizationProvider? AuthorizationProvider { get; set; } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/Auth/BasicServerAuthorizationProvider.cs b/src/ModelContextProtocol/Server/Auth/BasicServerAuthorizationProvider.cs new file mode 100644 index 00000000..48950ea1 --- /dev/null +++ b/src/ModelContextProtocol/Server/Auth/BasicServerAuthorizationProvider.cs @@ -0,0 +1,47 @@ +using ModelContextProtocol.Protocol.Auth; + +namespace ModelContextProtocol.Server.Auth; + +/// +/// A basic implementation of . +/// +/// +/// This implementation is intended as a starting point for server developers. In production environments, +/// it should be extended or replaced with a more robust implementation that integrates with your +/// authentication system (e.g., OAuth 2.0 server, identity provider, etc.) +/// +/// +/// Initializes a new instance of the class +/// with the specified resource metadata and token validator. +/// +/// The protected resource metadata. +/// A function that validates access tokens. If not provided, a function that always returns true will be used. +public class BasicServerAuthorizationProvider( + ProtectedResourceMetadata resourceMetadata, + Func>? tokenValidator = null) : IServerAuthorizationProvider +{ + private readonly ProtectedResourceMetadata _resourceMetadata = resourceMetadata ?? throw new ArgumentNullException(nameof(resourceMetadata)); + private readonly Func> _tokenValidator = tokenValidator ?? (_ => Task.FromResult(true)); + + /// + public ProtectedResourceMetadata GetProtectedResourceMetadata() => _resourceMetadata; + + /// + public async Task ValidateTokenAsync(string authorizationHeader) + { + // Extract the token from the Authorization header + if (string.IsNullOrEmpty(authorizationHeader) || !authorizationHeader.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + return false; + } + + var token = authorizationHeader["Bearer ".Length..].Trim(); + if (string.IsNullOrEmpty(token)) + { + return false; + } + + // Validate the token + return await _tokenValidator(token); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index ae0e7afc..7faad5c9 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -66,6 +66,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? SetResourcesHandler(options); SetSetLoggingLevelHandler(options); SetCompletionHandler(options); + SetAuthorizationHandler(); SetPingHandler(); // Register any notification handlers that were provided. @@ -503,6 +504,13 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) McpJsonUtilities.JsonContext.Default.EmptyResult); } + private void SetAuthorizationHandler() + { + // The authorization capability is handled via middleware in ASP.NET Core, + // so we don't need to set up any special handlers here. + // We just make sure to include the capability in the ServerCapabilities. + } + private ValueTask InvokeHandlerAsync( Func, CancellationToken, ValueTask> handler, TParams? args, diff --git a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs index b759ba97..7b33940b 100644 --- a/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs +++ b/src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs @@ -122,6 +122,13 @@ internal static bool IsValidMcpToolSchema(JsonElement element) [JsonSerializable(typeof(SubscribeRequestParams))] [JsonSerializable(typeof(UnsubscribeRequestParams))] [JsonSerializable(typeof(IReadOnlyDictionary))] + + // Authorization-related types + [JsonSerializable(typeof(Protocol.Auth.ProtectedResourceMetadata))] + [JsonSerializable(typeof(Protocol.Auth.AuthorizationServerMetadata))] + [JsonSerializable(typeof(Protocol.Auth.ClientMetadata))] + [JsonSerializable(typeof(Protocol.Auth.ClientRegistration))] + [JsonSerializable(typeof(Protocol.Auth.Token))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/src/ModelContextProtocol/Utils/SynchronizedValue.cs b/src/ModelContextProtocol/Utils/SynchronizedValue.cs new file mode 100644 index 00000000..106bf29b --- /dev/null +++ b/src/ModelContextProtocol/Utils/SynchronizedValue.cs @@ -0,0 +1,75 @@ +namespace ModelContextProtocol.Utils; + +/// +/// Provides a thread-safe synchronized value with locking functionality. +/// +/// The type of value to synchronize. +internal class SynchronizedValue where T : class +{ + private readonly SemaphoreSlim _semaphore = new(1, 1); + private T _value; + + /// + /// Initializes a new instance of the class. + /// + /// The initial value. + public SynchronizedValue(T initialValue) + { + _value = initialValue; + } + + /// + /// Gets the current value without locking. + /// + /// + /// This property should only be used when thread safety is not required. + /// + public T UnsafeValue => _value; + + /// + /// Acquires a lock on the value and provides access to it. + /// + /// A disposable that provides access to the value and releases the lock when disposed. + public async Task LockAsync() + { + await _semaphore.WaitAsync().ConfigureAwait(false); + return new SynchronizedValueHandle(this); + } + + /// + /// Provides a handle to access the synchronized value while holding a lock. + /// + public class SynchronizedValueHandle : IDisposable + { + private readonly SynchronizedValue _parent; + private bool _disposed; + + internal SynchronizedValueHandle(SynchronizedValue parent) + { + _parent = parent; + } + + /// + /// Gets or sets the synchronized value. + /// + public T Value + { + get => _parent._value; + set => _parent._value = value; + } + + /// + /// Releases the lock on the synchronized value. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + _disposed = true; + _parent._semaphore.Release(); + } + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs new file mode 100644 index 00000000..e4af9c91 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/Auth/ProtectedResourceMetadataTests.cs @@ -0,0 +1,61 @@ +using ModelContextProtocol.Protocol.Auth; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Protocol.Auth; + +public class ProtectedResourceMetadataTests +{ + [Fact] + public void ProtectedResourceMetadata_JsonSerialization_Works() + { + // Arrange + var metadata = new ProtectedResourceMetadata + { + Resource = new Uri("http://localhost:7071"), + AuthorizationServers = [new Uri("https://login.microsoftonline.com/tenant/v2.0")], + BearerMethodsSupported = ["header"], + ScopesSupported = ["mcp.tools", "mcp.prompts"], + ResourceDocumentation = new Uri("https://example.com/docs") + }; + + // Act + var json = JsonSerializer.Serialize(metadata); + var deserialized = JsonSerializer.Deserialize(json); + + // Assert + Assert.NotNull(deserialized); + Assert.Equal(metadata.Resource, deserialized.Resource); + Assert.Equal(metadata.AuthorizationServers[0], deserialized.AuthorizationServers[0]); + Assert.Equal("header", deserialized.BearerMethodsSupported![0]); + Assert.Equal(2, deserialized.ScopesSupported!.Length); + Assert.Contains("mcp.tools", deserialized.ScopesSupported!); + Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); + Assert.Equal(metadata.ResourceDocumentation, deserialized.ResourceDocumentation); + } + + [Fact] + public void ProtectedResourceMetadata_JsonDeserialization_WorksWithStringProperties() + { + // Arrange + var json = @"{ + ""resource"": ""http://localhost:7071"", + ""authorization_servers"": [""https://login.microsoftonline.com/tenant/v2.0""], + ""bearer_methods_supported"": [""header""], + ""scopes_supported"": [""mcp.tools"", ""mcp.prompts""], + ""resource_documentation"": ""https://example.com/docs"" + }"; + + // Act + var deserialized = JsonSerializer.Deserialize(json); + + // Assert + Assert.NotNull(deserialized); + Assert.Equal(new Uri("http://localhost:7071"), deserialized.Resource); + Assert.Equal(new Uri("https://login.microsoftonline.com/tenant/v2.0"), deserialized.AuthorizationServers[0]); + Assert.Equal("header", deserialized.BearerMethodsSupported![0]); + Assert.Equal(2, deserialized.ScopesSupported!.Length); + Assert.Contains("mcp.tools", deserialized.ScopesSupported!); + Assert.Contains("mcp.prompts", deserialized.ScopesSupported!); + Assert.Equal(new Uri("https://example.com/docs"), deserialized.ResourceDocumentation); + } +}