From 26e8c62f2fefb7d1ae484c147b36ff3876d3e799 Mon Sep 17 00:00:00 2001
From: gdsouza <glenantony_dsouza@intuit.com>
Date: Wed, 4 Jun 2025 16:47:26 +0530
Subject: [PATCH] Add support for managing users SSH key related operations

---
 README.md               |  10 ++
 pkg/github/tools.go     |   3 +
 pkg/github/user.go      | 163 +++++++++++++++++++
 pkg/github/user_test.go | 343 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 519 insertions(+)
 create mode 100644 pkg/github/user.go
 create mode 100644 pkg/github/user_test.go

diff --git a/README.md b/README.md
index 7b9e20fc3..f4ac45024 100644
--- a/README.md
+++ b/README.md
@@ -290,6 +290,16 @@ export GITHUB_MCP_TOOL_ADD_ISSUE_COMMENT_DESCRIPTION="an alternative description
 - **get_me** - Get details of the authenticated user
   - No parameters required
 
+- **list_users_public_ssh_keys** - "Lists the public SSH keys for the authenticated user's GitHub account
+  - No parameters required
+
+- **get_users_public_ssh_key** - View extended details for a single public SSH key
+  - `key_id`: Key Id (number, required)
+
+- **add_users_public_ssh_key** - Adds a public SSH key to the authenticated user's GitHub account
+  - `title`: Title of the key (string, optional)
+  - `key`: Public key contents (string, required)
+
 ### Issues
 
 - **get_issue** - Gets the contents of an issue within a repository
diff --git a/pkg/github/tools.go b/pkg/github/tools.go
index ab0528174..7448c7b50 100644
--- a/pkg/github/tools.go
+++ b/pkg/github/tools.go
@@ -56,6 +56,9 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
 	users := toolsets.NewToolset("users", "GitHub User related tools").
 		AddReadTools(
 			toolsets.NewServerTool(SearchUsers(getClient, t)),
+			toolsets.NewServerTool(ListUsersPublicSSHKeys(getClient, t)),
+			toolsets.NewServerTool(GetUsersPublicSSHKey(getClient, t)),
+			toolsets.NewServerTool(AddUsersPublicSSHKey(getClient, t)),
 		)
 	pullRequests := toolsets.NewToolset("pull_requests", "GitHub Pull Request related tools").
 		AddReadTools(
diff --git a/pkg/github/user.go b/pkg/github/user.go
new file mode 100644
index 000000000..c2a13a9c0
--- /dev/null
+++ b/pkg/github/user.go
@@ -0,0 +1,163 @@
+package github
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+
+	"github.com/github/github-mcp-server/pkg/translations"
+	"github.com/google/go-github/v72/github"
+	"github.com/mark3labs/mcp-go/mcp"
+	"github.com/mark3labs/mcp-go/server"
+)
+
+// ListUsersPublicSSHKeys creates a tool to list public ssh keys for user
+func ListUsersPublicSSHKeys(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+	return mcp.NewTool("list_users_public_ssh_keys",
+			mcp.WithDescription(t("TOOL_LIST_USERS_PUBLIC_SSH_KEYS", "Lists the public SSH keys for the authenticated user's GitHub account")),
+			mcp.WithToolAnnotation(mcp.ToolAnnotation{
+				Title:        t("TOOL_LIST_USERS_PUBLIC_SSH_KEYS_USER_TITLE", "List users public ssh keys"),
+				ReadOnlyHint: toBoolPtr(true),
+			}),
+			WithPagination(),
+		),
+		func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+			pagination, err := OptionalPaginationParams(request)
+			if err != nil {
+				return mcp.NewToolResultError(err.Error()), nil
+			}
+
+			opts := &github.ListOptions{
+				Page:    pagination.page,
+				PerPage: pagination.perPage,
+			}
+
+			client, err := getClient(ctx)
+			if err != nil {
+				return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+			}
+			result, resp, err := client.Users.ListKeys(ctx, "", opts)
+			if err != nil {
+				return nil, fmt.Errorf("failed to list users ssh keys: %w", err)
+			}
+			defer func() { _ = resp.Body.Close() }()
+
+			if resp.StatusCode != 200 {
+				body, err := io.ReadAll(resp.Body)
+				if err != nil {
+					return nil, fmt.Errorf("failed to read response body: %w", err)
+				}
+				return mcp.NewToolResultError(fmt.Sprintf("failed to list users ssh keys: %s", string(body))), nil
+			}
+
+			r, err := json.Marshal(result)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal response: %w", err)
+			}
+
+			return mcp.NewToolResultText(string(r)), nil
+		}
+}
+
+// GetUsersPublicSSHKey creates a tool to get public ssh key for user
+func GetUsersPublicSSHKey(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+	return mcp.NewTool("get_users_public_ssh_key",
+			mcp.WithDescription(t("TOOL_GET_USERS_PUBLIC_SSH_KEY", "View extended details for a single public SSH key")),
+			mcp.WithToolAnnotation(mcp.ToolAnnotation{
+				Title:        t("TOOL_GET_USERS_PUBLIC_SSH_KEY_USER_TITLE", "Get public ssh key details"),
+				ReadOnlyHint: toBoolPtr(true),
+			}),
+			mcp.WithNumber("key_id",
+				mcp.Required(),
+				mcp.Description("The unique identifier of the key"),
+			),
+		),
+		func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+			client, err := getClient(ctx)
+			if err != nil {
+				return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+			}
+			keyId, err := RequiredInt(request, "key_id")
+			if err != nil {
+				return mcp.NewToolResultError(err.Error()), nil
+			}
+			result, resp, err := client.Users.GetKey(ctx, int64(keyId))
+			if err != nil {
+				return nil, fmt.Errorf("failed to get ssh key details: %w", err)
+			}
+			defer func() { _ = resp.Body.Close() }()
+
+			if resp.StatusCode != 200 {
+				body, err := io.ReadAll(resp.Body)
+				if err != nil {
+					return nil, fmt.Errorf("failed to read response body: %w", err)
+				}
+				return mcp.NewToolResultError(fmt.Sprintf("failed to get ssh key details: %s", string(body))), nil
+			}
+
+			r, err := json.Marshal(result)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal response: %w", err)
+			}
+
+			return mcp.NewToolResultText(string(r)), nil
+		}
+}
+
+// AddPublicSSHKey Adds a public SSH key to the authenticated user's GitHub account
+func AddUsersPublicSSHKey(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
+	return mcp.NewTool("add_users_public_ssh_key",
+			mcp.WithDescription(t("TOOL_ADD_USERS_PUBLIC_SSH_KEY", "Adds a public SSH key to the authenticated user's GitHub account")),
+			mcp.WithToolAnnotation(mcp.ToolAnnotation{
+				Title:        t("TOOL_ADD_USERS_PUBLIC_SSH_KEY_USER_TITLE", "Add users public ssh key"),
+				ReadOnlyHint: toBoolPtr(true),
+			}),
+			mcp.WithString("title",
+				mcp.Description("A descriptive name for the new key"),
+			),
+			mcp.WithString("key",
+				mcp.Required(),
+				mcp.Description("The public SSH key to add to your GitHub account"),
+			),
+		),
+		func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+
+			client, err := getClient(ctx)
+			if err != nil {
+				return nil, fmt.Errorf("failed to get GitHub client: %w", err)
+			}
+			title, err := OptionalParam[string](request, "title")
+			if err != nil {
+				return mcp.NewToolResultError(err.Error()), nil
+			}
+			key, err := requiredParam[string](request, "key")
+			if err != nil {
+				return mcp.NewToolResultError(err.Error()), nil
+			}
+			githubKey := &github.Key{
+				Title: &title,
+				Key:   &key,
+			}
+			result, resp, err := client.Users.CreateKey(ctx, githubKey)
+			if err != nil {
+				return nil, fmt.Errorf("failed to add ssh key: %w", err)
+			}
+			defer func() { _ = resp.Body.Close() }()
+
+			if resp.StatusCode != 201 {
+				body, err := io.ReadAll(resp.Body)
+				if err != nil {
+					return nil, fmt.Errorf("failed to read response body: %w", err)
+				}
+				return mcp.NewToolResultError(fmt.Sprintf("failed to add ssh key: %s", string(body))), nil
+			}
+
+			r, err := json.Marshal(result)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal response: %w", err)
+			}
+
+			return mcp.NewToolResultText(string(r)), nil
+		}
+}
diff --git a/pkg/github/user_test.go b/pkg/github/user_test.go
new file mode 100644
index 000000000..ecc1cdb07
--- /dev/null
+++ b/pkg/github/user_test.go
@@ -0,0 +1,343 @@
+package github
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"testing"
+
+	"github.com/github/github-mcp-server/pkg/translations"
+	"github.com/google/go-github/v72/github"
+	"github.com/migueleliasweb/go-github-mock/src/mock"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func Test_ListUsersPublicSSHKeys(t *testing.T) {
+	mockClient := github.NewClient(nil)
+	tool, _ := ListUsersPublicSSHKeys(stubGetClientFn(mockClient), translations.NullTranslationHelper)
+	assert.Equal(t, "list_users_public_ssh_keys", tool.Name)
+	assert.NotEmpty(t, tool.Description)
+	assert.Contains(t, tool.InputSchema.Properties, "page")
+	assert.Contains(t, tool.InputSchema.Properties, "perPage")
+
+	// Setup mock results
+	mockListSSHKeyResult := []*github.Key{
+		{
+			ID:       github.Ptr(int64(1)),
+			Key:      github.Ptr("ssh test key"),
+			URL:      github.Ptr("test url"),
+			Title:    github.Ptr("test key 1"),
+			ReadOnly: github.Ptr(true),
+			Verified: github.Ptr(true),
+		},
+		{
+			ID:       github.Ptr(int64(2)),
+			Key:      github.Ptr("ssh test key"),
+			URL:      github.Ptr("test url"),
+			Title:    github.Ptr("test key 2"),
+			ReadOnly: github.Ptr(true),
+			Verified: github.Ptr(true),
+		},
+	}
+	tests := []struct {
+		name           string
+		mockedClient   *http.Client
+		requestArgs    map[string]any
+		expectError    bool
+		expectedResult []*github.Key
+		expectedErrMsg string
+	}{
+		{
+			name: "list public ssh keys",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatchHandler(
+					mock.GetUserKeys,
+					expectQueryParams(t, map[string]string{
+						"page":     "2",
+						"per_page": "10",
+					}).andThen(
+						mockResponse(t, http.StatusOK, mockListSSHKeyResult),
+					),
+				),
+			),
+			requestArgs: map[string]any{
+				"page":    float64(2),
+				"perPage": float64(10),
+			},
+			expectError:    false,
+			expectedResult: mockListSSHKeyResult,
+		},
+		{
+			name: "list public ssh keys with default pagination",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatchHandler(
+					mock.GetUserKeys,
+					expectQueryParams(t, map[string]string{
+						"page":     "1",
+						"per_page": "30",
+					}).andThen(
+						mockResponse(t, http.StatusOK, mockListSSHKeyResult),
+					),
+				),
+			),
+			expectError:    false,
+			expectedResult: mockListSSHKeyResult,
+		},
+		{
+			name: "list ssh key fails",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatchHandler(
+					mock.GetUserKeys,
+					http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+						w.WriteHeader(http.StatusUnauthorized)
+						_, _ = w.Write([]byte(`{"message": "bad permission"}`))
+					}),
+				),
+			),
+			expectError:    true,
+			expectedErrMsg: "failed to list users ssh keys",
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			// Setup client with mock
+			client := github.NewClient(tc.mockedClient)
+			_, handler := ListUsersPublicSSHKeys(stubGetClientFn(client), translations.NullTranslationHelper)
+
+			// Create call request
+			request := createMCPRequest(tc.requestArgs)
+
+			// Call handler
+			result, err := handler(context.Background(), request)
+
+			// Verify results
+			if tc.expectError {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tc.expectedErrMsg)
+				return
+			}
+
+			require.NoError(t, err)
+
+			// Parse the result and get the text content if no error
+			textContent := getTextResult(t, result)
+
+			// Unmarshal and verify the result
+			var returnedResult []*github.Key
+			err = json.Unmarshal([]byte(textContent.Text), &returnedResult)
+			require.NoError(t, err)
+			assert.Equal(t, len(tc.expectedResult), len(returnedResult))
+			for i, keyData := range returnedResult {
+				assert.Equal(t, tc.expectedResult[i].ID, keyData.ID)
+				assert.Equal(t, tc.expectedResult[i].Title, keyData.Title)
+				assert.Equal(t, tc.expectedResult[i].URL, keyData.URL)
+				assert.Equal(t, tc.expectedResult[i].Key, keyData.Key)
+				assert.Equal(t, tc.expectedResult[i].Verified, keyData.Verified)
+				assert.Equal(t, tc.expectedResult[i].ReadOnly, keyData.ReadOnly)
+			}
+		})
+	}
+}
+
+func Test_GetUsersPublicSSHKey(t *testing.T) {
+	mockClient := github.NewClient(nil)
+	tool, _ := GetUsersPublicSSHKey(stubGetClientFn(mockClient), translations.NullTranslationHelper)
+	assert.Equal(t, "get_users_public_ssh_key", tool.Name)
+	assert.NotEmpty(t, tool.Description)
+	assert.Contains(t, tool.InputSchema.Properties, "key_id")
+	assert.NotContains(t, tool.InputSchema.Properties, "page")
+	assert.NotContains(t, tool.InputSchema.Properties, "perPage")
+
+	// Setup mock results
+	mockGetSSHKeyResult := &github.Key{
+		ID:       github.Ptr(int64(1)),
+		Key:      github.Ptr("ssh test key"),
+		URL:      github.Ptr("test url"),
+		Title:    github.Ptr("test key 1"),
+		ReadOnly: github.Ptr(true),
+		Verified: github.Ptr(true),
+	}
+	tests := []struct {
+		name           string
+		mockedClient   *http.Client
+		requestArgs    map[string]any
+		expectError    bool
+		expectedResult *github.Key
+		expectedErrMsg string
+	}{
+		{
+			name: "get public ssh key",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatchHandler(
+					mock.GetUserKeysByKeyId,
+					expectPath(t, "/user/keys/1").
+						andThen(
+							mockResponse(t, http.StatusOK, mockGetSSHKeyResult),
+						),
+				),
+			),
+			requestArgs: map[string]any{
+				"key_id": float64(1),
+			},
+			expectError:    false,
+			expectedResult: mockGetSSHKeyResult,
+		},
+		{
+			name: "get public ssh key with bad wrong key",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatchHandler(
+					mock.GetUserKeysByKeyId,
+					http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+						w.WriteHeader(http.StatusNotFound)
+						_, _ = w.Write([]byte(`{"message": "key not found"}`))
+					}),
+				),
+			),
+			requestArgs: map[string]any{
+				"key_id": float64(2),
+			},
+			expectError:    true,
+			expectedErrMsg: "failed to get ssh key details",
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			// Setup client with mock
+			client := github.NewClient(tc.mockedClient)
+			_, handler := GetUsersPublicSSHKey(stubGetClientFn(client), translations.NullTranslationHelper)
+
+			// Create call request
+			request := createMCPRequest(tc.requestArgs)
+
+			// Call handler
+			result, err := handler(context.Background(), request)
+
+			// Verify results
+			if tc.expectError {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tc.expectedErrMsg)
+				return
+			}
+
+			require.NoError(t, err)
+
+			// Parse the result and get the text content if no error
+			textContent := getTextResult(t, result)
+			// Unmarshal and verify the result
+			var returnedResult *github.Key
+			err = json.Unmarshal([]byte(textContent.Text), &returnedResult)
+			require.NoError(t, err)
+			assert.Equal(t, tc.expectedResult.ID, returnedResult.ID)
+			assert.Equal(t, tc.expectedResult.Key, returnedResult.Key)
+			assert.Equal(t, tc.expectedResult.Title, returnedResult.Title)
+			assert.Equal(t, tc.expectedResult.URL, returnedResult.URL)
+			assert.Equal(t, tc.expectedResult.Verified, returnedResult.Verified)
+			assert.Equal(t, tc.expectedResult.ReadOnly, returnedResult.ReadOnly)
+		})
+	}
+}
+
+func Test_AddUsersPublicSSHKey(t *testing.T) {
+	mockClient := github.NewClient(nil)
+	tool, _ := AddUsersPublicSSHKey(stubGetClientFn(mockClient), translations.NullTranslationHelper)
+	assert.Equal(t, "add_users_public_ssh_key", tool.Name)
+	assert.NotEmpty(t, tool.Description)
+	assert.Contains(t, tool.InputSchema.Properties, "title")
+	assert.Contains(t, tool.InputSchema.Properties, "key")
+	assert.NotContains(t, tool.InputSchema.Properties, "page")
+	assert.NotContains(t, tool.InputSchema.Properties, "perPage")
+
+	// Setup mock results
+	mockAddKeyResult := &github.Key{
+		ID:       github.Ptr(int64(1)),
+		Key:      github.Ptr("ssh test key"),
+		URL:      github.Ptr("test url"),
+		Title:    github.Ptr("test key 1"),
+		ReadOnly: github.Ptr(true),
+		Verified: github.Ptr(true),
+	}
+	tests := []struct {
+		name           string
+		mockedClient   *http.Client
+		requestArgs    map[string]any
+		expectError    bool
+		expectedResult *github.Key
+		expectedErrMsg string
+	}{
+		{
+			name: "add public ssh key",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatchHandler(
+					mock.PostUserKeys,
+					expectRequestBody(t, map[string]any{
+						"title": "test key 1",
+						"key":   "ssh test key",
+					}).
+						andThen(
+							mockResponse(t, http.StatusCreated, mockAddKeyResult),
+						),
+				),
+			),
+			requestArgs: map[string]any{
+				"title": "test key 1",
+				"key":   "ssh test key",
+			},
+			expectError:    false,
+			expectedResult: mockAddKeyResult,
+		},
+		{
+			name: "add public ssh key fails",
+			mockedClient: mock.NewMockedHTTPClient(
+				mock.WithRequestMatchHandler(
+					mock.PostUserKeys,
+					http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+						w.WriteHeader(http.StatusInternalServerError)
+						_, _ = w.Write([]byte(`{"message": "something bad happened"}`))
+					}),
+				),
+			),
+			requestArgs: map[string]any{
+				"title": "test key 1",
+				"key":   "ssh test key",
+			},
+			expectError:    true,
+			expectedErrMsg: "failed to add ssh key",
+		},
+	}
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			// Setup client with mock
+			client := github.NewClient(tc.mockedClient)
+			_, handler := AddUsersPublicSSHKey(stubGetClientFn(client), translations.NullTranslationHelper)
+
+			// Create call request
+			request := createMCPRequest(tc.requestArgs)
+
+			// Call handler
+			result, err := handler(context.Background(), request)
+
+			// Verify results
+			if tc.expectError {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tc.expectedErrMsg)
+				return
+			}
+
+			require.NoError(t, err)
+
+			// Parse the result and get the text content if no error
+			textContent := getTextResult(t, result)
+			// Unmarshal and verify the result
+			var returnedResult *github.Key
+			err = json.Unmarshal([]byte(textContent.Text), &returnedResult)
+			require.NoError(t, err)
+			assert.Equal(t, tc.expectedResult.ID, returnedResult.ID)
+			assert.Equal(t, tc.expectedResult.Key, returnedResult.Key)
+			assert.Equal(t, tc.expectedResult.Title, returnedResult.Title)
+			assert.Equal(t, tc.expectedResult.URL, returnedResult.URL)
+			assert.Equal(t, tc.expectedResult.Verified, returnedResult.Verified)
+			assert.Equal(t, tc.expectedResult.ReadOnly, returnedResult.ReadOnly)
+		})
+	}
+}