diff --git a/client/stdio_test.go b/client/stdio_test.go index 7bffa3b2..48514d91 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -7,7 +7,7 @@ import ( "log/slog" "os" "os/exec" - "path/filepath" + "runtime" "sync" "testing" "time" @@ -19,13 +19,18 @@ func compileTestServer(outputPath string) error { cmd := exec.Command( "go", "build", + "-buildmode=pie", "-o", outputPath, "../testdata/mockstdio_server.go", ) + tmpCache, _ := os.MkdirTemp("", "gocache") + cmd.Env = append(os.Environ(), "GOCACHE="+tmpCache) + if output, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) } + // Verify the binary was actually created if _, err := os.Stat(outputPath); os.IsNotExist(err) { return fmt.Errorf("mock server binary not found at %s after compilation", outputPath) } @@ -33,10 +38,22 @@ func compileTestServer(outputPath string) error { } func TestStdioMCPClient(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + + // Add .exe suffix on Windows + if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first + mockServerPath += ".exe" + } + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index aa728ec6..cb25bf79 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "os/exec" - "path/filepath" "runtime" "sync" "testing" @@ -19,25 +18,41 @@ func compileTestServer(outputPath string) error { cmd := exec.Command( "go", "build", + "-buildmode=pie", "-o", outputPath, "../../testdata/mockstdio_server.go", ) + tmpCache, _ := os.MkdirTemp("", "gocache") + cmd.Env = append(os.Environ(), "GOCACHE="+tmpCache) + if output, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("compilation failed: %v\nOutput: %s", err, output) } + // Verify the binary was actually created + if _, err := os.Stat(outputPath); os.IsNotExist(err) { + return fmt.Errorf("mock server binary not found at %s after compilation", outputPath) + } return nil } func TestStdio(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -48,9 +63,9 @@ func TestStdio(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - err := stdio.Start(ctx) - if err != nil { - t.Fatalf("Failed to start Stdio transport: %v", err) + startErr := stdio.Start(ctx) + if startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) } defer stdio.Close() @@ -307,13 +322,22 @@ func TestStdioErrors(t *testing.T) { }) t.Run("RequestBeforeStart", func(t *testing.T) { - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -328,23 +352,31 @@ func TestStdioErrors(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() - _, err := uninitiatedStdio.SendRequest(ctx, request) - if err == nil { + _, reqErr := uninitiatedStdio.SendRequest(ctx, request) + if reqErr == nil { t.Errorf("Expected SendRequest to panic before Start(), but it didn't") - } else if err.Error() != "stdio client not started" { - t.Errorf("Expected error 'stdio client not started', got: %v", err) + } else if reqErr.Error() != "stdio client not started" { + t.Errorf("Expected error 'stdio client not started', got: %v", reqErr) } }) t.Run("RequestAfterClose", func(t *testing.T) { - // Compile mock server - mockServerPath := filepath.Join(os.TempDir(), "mockstdio_server") + // Create a temporary file for the mock server + tempFile, err := os.CreateTemp("", "mockstdio_server") + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + tempFile.Close() + mockServerPath := tempFile.Name() + // Add .exe suffix on Windows if runtime.GOOS == "windows" { + os.Remove(mockServerPath) // Remove the empty file first mockServerPath += ".exe" } - if err := compileTestServer(mockServerPath); err != nil { - t.Fatalf("Failed to compile mock server: %v", err) + + if compileErr := compileTestServer(mockServerPath); compileErr != nil { + t.Fatalf("Failed to compile mock server: %v", compileErr) } defer os.Remove(mockServerPath) @@ -353,8 +385,8 @@ func TestStdioErrors(t *testing.T) { // Start the transport ctx := context.Background() - if err := stdio.Start(ctx); err != nil { - t.Fatalf("Failed to start Stdio transport: %v", err) + if startErr := stdio.Start(ctx); startErr != nil { + t.Fatalf("Failed to start Stdio transport: %v", startErr) } // Close the transport - ignore errors like "broken pipe" since the process might exit already @@ -370,8 +402,8 @@ func TestStdioErrors(t *testing.T) { Method: "ping", } - _, err := stdio.SendRequest(ctx, request) - if err == nil { + _, sendErr := stdio.SendRequest(ctx, request) + if sendErr == nil { t.Errorf("Expected error when sending request after close, got nil") } }) diff --git a/mcp/types.go b/mcp/types.go index c79baae1..516f90b4 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -132,7 +132,7 @@ type NotificationParams struct { } // MarshalJSON implements custom JSON marshaling -func (p *NotificationParams) MarshalJSON() ([]byte, error) { +func (p NotificationParams) MarshalJSON() ([]byte, error) { // Create a map to hold all fields m := make(map[string]interface{}) diff --git a/server/server.go b/server/server.go index 95831ebd..8aac05ca 100644 --- a/server/server.go +++ b/server/server.go @@ -856,7 +856,7 @@ func (s *MCPServer) handleToolCall( session := ClientSessionFromContext(ctx) if session != nil { - if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk { if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { var sessionOk bool tool, sessionOk = sessionTools[request.Params.Name] diff --git a/server/session_test.go b/server/session_test.go index d1d0bc79..42def221 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "encoding/json" "errors" "sync" "testing" @@ -295,6 +296,64 @@ func TestMCPServer_AddSessionTool(t *testing.T) { assert.Contains(t, session.GetSessionTools(), "session-tool-helper") } +func TestMCPServer_CallSessionTool(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) + + // Add global tool + server.AddTool(mcp.NewTool("test_tool"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("global result"), nil + }) + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithTools{ + sessionID: "session-1", + notificationChannel: sessionChan, + initialized: true, + } + + // Register the session + err := server.RegisterSession(context.Background(), session) + require.NoError(t, err) + + // Add session-specific tool with the same name to override the global tool + err = server.AddSessionTool( + session.SessionID(), + mcp.NewTool("test_tool"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("session result"), nil + }, + ) + require.NoError(t, err) + + // Call the tool using session context + sessionCtx := server.WithContext(context.Background(), session) + toolRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "test_tool", + }, + } + requestBytes, err := json.Marshal(toolRequest) + if err != nil { + t.Fatalf("Failed to marshal tool request: %v", err) + } + + response := server.HandleMessage(sessionCtx, requestBytes) + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + + callToolResult, ok := resp.Result.(mcp.CallToolResult) + assert.True(t, ok) + + // Since we specify a tool with the same name for current session, the expected text should be "session result" + if text := callToolResult.Content[0].(mcp.TextContent).Text; text != "session result" { + t.Errorf("Expected result 'session result', got %q", text) + } +} + func TestMCPServer_DeleteSessionTools(t *testing.T) { server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true)) ctx := context.Background() diff --git a/server/sse.go b/server/sse.go index e380d20a..018657e6 100644 --- a/server/sse.go +++ b/server/sse.go @@ -28,6 +28,7 @@ type sseSession struct { requestID atomic.Int64 notificationChannel chan mcp.JSONRPCNotification initialized atomic.Bool + tools sync.Map // stores session-specific tools } // SSEContextFunc is a function that takes an existing context and the current @@ -58,7 +59,34 @@ func (s *sseSession) Initialized() bool { return s.initialized.Load() } -var _ ClientSession = (*sseSession)(nil) +func (s *sseSession) GetSessionTools() map[string]ServerTool { + tools := make(map[string]ServerTool) + s.tools.Range(func(key, value interface{}) bool { + if tool, ok := value.(ServerTool); ok { + tools[key.(string)] = tool + } + return true + }) + return tools +} + +func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { + // Clear existing tools + s.tools.Range(func(key, _ interface{}) bool { + s.tools.Delete(key) + return true + }) + + // Set new tools + for name, tool := range tools { + s.tools.Store(name, tool) + } +} + +var ( + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) +) // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. @@ -107,13 +135,22 @@ func WithBaseURL(baseURL string) SSEOption { } } -// WithBasePath adds a new option for setting a static base path -func WithBasePath(basePath string) SSEOption { +// WithStaticBasePath adds a new option for setting a static base path +func WithStaticBasePath(basePath string) SSEOption { return func(s *SSEServer) { s.basePath = normalizeURLPath(basePath) } } +// WithBasePath adds a new option for setting a static base path. +// +// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version. +// +//go:deprecated +func WithBasePath(basePath string) SSEOption { + return WithStaticBasePath(basePath) +} + // WithDynamicBasePath accepts a function for generating the base path. This is // useful for cases where the base path is not known at the time of SSE server // creation, such as when using a reverse proxy or when the server is mounted @@ -339,7 +376,12 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { } messageBytes, _ := json.Marshal(message) pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) - session.eventQueue <- pingMsg + select { + case session.eventQueue <- pingMsg: + // Message sent successfully + case <-session.done: + return + } case <-session.done: return case <-r.Context().Done(): @@ -423,13 +465,21 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } + // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled. + // this is required because the http ctx will be canceled when the client disconnects + detachedCtx := context.WithoutCancel(ctx) + // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE w.WriteHeader(http.StatusAccepted) - go func() { + // Create a new context for handling the message that will be canceled when the message handling is done + messageCtx, cancel := context.WithCancel(detachedCtx) + + go func(ctx context.Context) { + defer cancel() + // Use the context that will be canceled when session is done // Process message through MCPServer response := s.server.HandleMessage(ctx, rawMessage) - // Only send response if there is one (not for notifications) if response != nil { var message string @@ -437,7 +487,6 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { // If there is an error marshalling the response, send a generic error response log.Printf("failed to marshal response: %v", err) message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n") - return } else { message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) } @@ -453,7 +502,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { log.Printf("Event queue full for session %s", sessionID) } } - }() + }(messageCtx) } // writeJSONRPCError writes a JSON-RPC error response with the given error details. diff --git a/server/sse_test.go b/server/sse_test.go index a121581a..393a70cf 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -24,7 +24,7 @@ func TestSSEServer(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") sseServer := NewSSEServer(mcpServer, WithBaseURL("http://localhost:8080"), - WithBasePath("/mcp"), + WithStaticBasePath("/mcp"), ) if sseServer == nil { @@ -62,7 +62,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - endpointEvent, err := readSeeEvent(sseResp) + endpointEvent, err := readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -195,7 +195,7 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -203,7 +203,6 @@ func TestSSEServer(t *testing.T) { strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) - fmt.Printf("========> %v", respFromSee) var response map[string]interface{} if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { t.Errorf( @@ -499,7 +498,7 @@ func TestSSEServer(t *testing.T) { t.Run("works as http.Handler with custom basePath", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") - sseServer := NewSSEServer(mcpServer, WithBasePath("/mcp")) + sseServer := NewSSEServer(mcpServer, WithStaticBasePath("/mcp")) ts := httptest.NewServer(sseServer) defer ts.Close() @@ -590,7 +589,7 @@ func TestSSEServer(t *testing.T) { defer sseResp.Body.Close() // Read the endpoint event - endpointEvent, err := readSeeEvent(sseResp) + endpointEvent, err := readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } @@ -632,16 +631,16 @@ func TestSSEServer(t *testing.T) { } // Verify response - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - respFromSee := strings.TrimSpace( + respFromSSE := strings.TrimSpace( strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) var response map[string]interface{} - if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -666,7 +665,8 @@ func TestSSEServer(t *testing.T) { t.Fatalf("Failed to marshal tool request: %v", err) } - req, err := http.NewRequest(http.MethodPost, messageURL, bytes.NewBuffer(requestBody)) + var req *http.Request + req, err = http.NewRequest(http.MethodPost, messageURL, bytes.NewBuffer(requestBody)) if err != nil { t.Fatalf("Failed to create tool request: %v", err) } @@ -679,17 +679,17 @@ func TestSSEServer(t *testing.T) { } defer resp.Body.Close() - endpointEvent, err = readSeeEvent(sseResp) + endpointEvent, err = readSSEEvent(sseResp) if err != nil { t.Fatalf("Failed to read SSE response: %v", err) } - respFromSee = strings.TrimSpace( + respFromSSE = strings.TrimSpace( strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], ) response = make(map[string]interface{}) - if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil { + if err := json.NewDecoder(strings.NewReader(respFromSSE)).Decode(&response); err != nil { t.Fatalf("Failed to decode response: %v", err) } @@ -716,7 +716,7 @@ func TestSSEServer(t *testing.T) { useFullURLForMessageEndpoint := false srv := &http.Server{} rands := []SSEOption{ - WithBasePath(basePath), + WithStaticBasePath(basePath), WithBaseURL(baseURL), WithMessageEndpoint(messageEndpoint), WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint), @@ -1129,9 +1129,280 @@ func TestSSEServer(t *testing.T) { }) } }) + + t.Run("SessionWithTools implementation", func(t *testing.T) { + // Create hooks to track sessions + hooks := &Hooks{} + var registeredSession *sseSession + hooks.AddOnRegisterSession(func(ctx context.Context, session ClientSession) { + if s, ok := session.(*sseSession); ok { + registeredSession = s + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Connect to SSE endpoint + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event to ensure session is established + _, err = readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + // Verify we got a session + if registeredSession == nil { + t.Fatal("Session was not registered via hook") + } + + // Test setting and getting tools + tools := map[string]ServerTool{ + "test_tool": { + Tool: mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + Annotations: mcp.ToolAnnotation{ + Title: "Test Tool", + }, + }, + Handler: func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("test"), nil + }, + }, + } + + // Test SetSessionTools + registeredSession.SetSessionTools(tools) + + // Test GetSessionTools + retrievedTools := registeredSession.GetSessionTools() + if len(retrievedTools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(retrievedTools)) + } + if tool, exists := retrievedTools["test_tool"]; !exists { + t.Error("Expected test_tool to exist") + } else if tool.Tool.Name != "test_tool" { + t.Errorf("Expected tool name test_tool, got %s", tool.Tool.Name) + } + + // Test concurrent access + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(2) + go func(i int) { + defer wg.Done() + tools := map[string]ServerTool{ + fmt.Sprintf("tool_%d", i): { + Tool: mcp.Tool{ + Name: fmt.Sprintf("tool_%d", i), + Description: fmt.Sprintf("Tool %d", i), + Annotations: mcp.ToolAnnotation{ + Title: fmt.Sprintf("Tool %d", i), + }, + }, + }, + } + registeredSession.SetSessionTools(tools) + }(i) + go func() { + defer wg.Done() + _ = registeredSession.GetSessionTools() + }() + } + wg.Wait() + + // Verify we can still get and set tools after concurrent access + finalTools := map[string]ServerTool{ + "final_tool": { + Tool: mcp.Tool{ + Name: "final_tool", + Description: "Final Tool", + Annotations: mcp.ToolAnnotation{ + Title: "Final Tool", + }, + }, + }, + } + registeredSession.SetSessionTools(finalTools) + retrievedTools = registeredSession.GetSessionTools() + if len(retrievedTools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(retrievedTools)) + } + if _, exists := retrievedTools["final_tool"]; !exists { + t.Error("Expected final_tool to exist") + } + }) + + t.Run("TestServerResponseMarshalError", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0", + WithResourceCapabilities(true, true), + WithHooks(&Hooks{ + OnAfterInitialize: []OnAfterInitializeFunc{ + func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + result.Result.Meta = map[string]interface{}{"invalid": func() {}} // marshal will fail + }, + }, + }), + ) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // Connect to SSE endpoint + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event + endpointEvent, err := readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + if !strings.Contains(endpointEvent, "event: endpoint") { + t.Fatalf("Expected endpoint event, got: %s", endpointEvent) + } + + // Extract message endpoint URL + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + // Send initialize request + initRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]interface{}{ + "protocolVersion": "2024-11-05", + "clientInfo": map[string]interface{}{ + "name": "test-client", + "version": "1.0.0", + }, + }, + } + + requestBody, err := json.Marshal(initRequest) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + resp, err := http.Post( + messageURL, + "application/json", + bytes.NewBuffer(requestBody), + ) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + + endpointEvent, err = readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + + if !strings.Contains(endpointEvent, "\"id\": null") { + t.Errorf("Expected id to be null") + } + }) + + t.Run("Message processing continues after we return back result to client", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + + processingCompleted := make(chan struct{}) + processingStarted := make(chan struct{}) + + mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + close(processingStarted) // signal for processing started + + select { + case <-ctx.Done(): // If this happens, the test will fail because processingCompleted won't be closed + return nil, fmt.Errorf("context was canceled") + case <-time.After(1 * time.Second): // Simulate processing time + // Successfully completed processing, now close the completed channel to signal completion + close(processingCompleted) + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "success", + }, + }, + }, nil + } + }) + + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + require.NoError(t, err, "Failed to connect to SSE endpoint") + defer sseResp.Body.Close() + + endpointEvent, err := readSSEEvent(sseResp) + require.NoError(t, err, "Failed to read SSE response") + require.Contains(t, endpointEvent, "event: endpoint", "Expected endpoint event") + + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + messageRequest := map[string]interface{}{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]interface{}{ + "name": "slowMethod", + "parameters": map[string]interface{}{}, + }, + } + + requestBody, err := json.Marshal(messageRequest) + require.NoError(t, err, "Failed to marshal request") + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, "POST", messageURL, bytes.NewBuffer(requestBody)) + require.NoError(t, err, "Failed to create request") + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err, "Failed to send message") + defer resp.Body.Close() + + require.Equal(t, http.StatusAccepted, resp.StatusCode, "Expected status 202 Accepted") + + // Wait for processing to start + select { + case <-processingStarted: // Processing has started, now cancel the client context to simulate disconnection + case <-time.After(2 * time.Second): + t.Fatal("Timed out waiting for processing to start") + } + + cancel() // cancel the client context to simulate disconnection + + // wait for processing to complete, if the test passes, it means the processing continued despite client disconnection + select { + case <-processingCompleted: + case <-time.After(2 * time.Second): + t.Fatal("Processing did not complete after client disconnection") + } + }) } -func readSeeEvent(sseResp *http.Response) (string, error) { +func readSSEEvent(sseResp *http.Response) (string, error) { buf := make([]byte, 1024) n, err := sseResp.Body.Read(buf) if err != nil {