diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index d7b42becd..f55395299 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -137,8 +137,11 @@ func runStdioServer(cfg runConfig) error { t, dumpTranslations := translations.TranslationHelper() + getClient := func(_ context.Context) (*gogithub.Client, error) { + return ghClient, nil // closing over client + } // Create - ghServer := github.NewServer(ghClient, version, cfg.readOnly, t) + ghServer := github.NewServer(getClient, version, cfg.readOnly, t) stdioServer := server.NewStdioServer(ghServer) stdLogger := stdlog.New(cfg.logger.Writer(), "stdioserver", 0) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index dc48bdb3e..4fc029bf6 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -13,7 +13,7 @@ import ( "github.com/mark3labs/mcp-go/server" ) -func GetCodeScanningAlert(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_code_scanning_alert", mcp.WithDescription(t("TOOL_GET_CODE_SCANNING_ALERT_DESCRIPTION", "Get details of a specific code scanning alert in a GitHub repository.")), mcp.WithString("owner", @@ -43,6 +43,11 @@ func GetCodeScanningAlert(client *github.Client, t translations.TranslationHelpe return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber)) if err != nil { return nil, fmt.Errorf("failed to get alert: %w", err) @@ -66,7 +71,7 @@ func GetCodeScanningAlert(client *github.Client, t translations.TranslationHelpe } } -func ListCodeScanningAlerts(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_code_scanning_alerts", mcp.WithDescription(t("TOOL_LIST_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts in a GitHub repository.")), mcp.WithString("owner", @@ -110,6 +115,10 @@ func ListCodeScanningAlerts(client *github.Client, t translations.TranslationHel return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity}) if err != nil { return nil, fmt.Errorf("failed to list alerts: %w", err) diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index f1f3a1dee..c9895e269 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -16,7 +16,7 @@ import ( func Test_GetCodeScanningAlert(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetCodeScanningAlert(mockClient, translations.NullTranslationHelper) + tool, _ := GetCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_code_scanning_alert", tool.Name) assert.NotEmpty(t, tool.Description) @@ -82,7 +82,7 @@ func Test_GetCodeScanningAlert(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetCodeScanningAlert(client, translations.NullTranslationHelper) + _, handler := GetCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -118,7 +118,7 @@ func Test_GetCodeScanningAlert(t *testing.T) { func Test_ListCodeScanningAlerts(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := ListCodeScanningAlerts(mockClient, translations.NullTranslationHelper) + tool, _ := ListCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_code_scanning_alerts", tool.Name) assert.NotEmpty(t, tool.Description) @@ -201,7 +201,7 @@ func Test_ListCodeScanningAlerts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListCodeScanningAlerts(client, translations.NullTranslationHelper) + _, handler := ListCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index c983fa269..16c34141c 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -15,7 +15,7 @@ import ( ) // GetIssue creates a tool to get details of a specific issue in a GitHub repository. -func GetIssue(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_issue", mcp.WithDescription(t("TOOL_GET_ISSUE_DESCRIPTION", "Get details of a specific issue in a GitHub repository")), mcp.WithString("owner", @@ -45,6 +45,10 @@ func GetIssue(client *github.Client, t translations.TranslationHelperFunc) (tool return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber) if err != nil { return nil, fmt.Errorf("failed to get issue: %w", err) @@ -69,7 +73,7 @@ func GetIssue(client *github.Client, t translations.TranslationHelperFunc) (tool } // AddIssueComment creates a tool to add a comment to an issue. -func AddIssueComment(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("add_issue_comment", mcp.WithDescription(t("TOOL_ADD_ISSUE_COMMENT_DESCRIPTION", "Add a comment to an existing issue")), mcp.WithString("owner", @@ -111,6 +115,10 @@ func AddIssueComment(client *github.Client, t translations.TranslationHelperFunc Body: github.Ptr(body), } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } createdComment, resp, err := client.Issues.CreateComment(ctx, owner, repo, issueNumber, comment) if err != nil { return nil, fmt.Errorf("failed to create comment: %w", err) @@ -135,7 +143,7 @@ func AddIssueComment(client *github.Client, t translations.TranslationHelperFunc } // SearchIssues creates a tool to search for issues and pull requests. -func SearchIssues(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_issues", mcp.WithDescription(t("TOOL_SEARCH_ISSUES_DESCRIPTION", "Search for issues and pull requests across GitHub repositories")), mcp.WithString("q", @@ -191,6 +199,10 @@ func SearchIssues(client *github.Client, t translations.TranslationHelperFunc) ( }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.Search.Issues(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search issues: %w", err) @@ -215,7 +227,7 @@ func SearchIssues(client *github.Client, t translations.TranslationHelperFunc) ( } // CreateIssue creates a tool to create a new issue in a GitHub repository. -func CreateIssue(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_issue", mcp.WithDescription(t("TOOL_CREATE_ISSUE_DESCRIPTION", "Create a new issue in a GitHub repository")), mcp.WithString("owner", @@ -305,6 +317,10 @@ func CreateIssue(client *github.Client, t translations.TranslationHelperFunc) (t Milestone: milestoneNum, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } issue, resp, err := client.Issues.Create(ctx, owner, repo, issueRequest) if err != nil { return nil, fmt.Errorf("failed to create issue: %w", err) @@ -329,7 +345,7 @@ func CreateIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // ListIssues creates a tool to list and filter repository issues -func ListIssues(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_issues", mcp.WithDescription(t("TOOL_LIST_ISSUES_DESCRIPTION", "List issues in a GitHub repository with filtering options")), mcp.WithString("owner", @@ -419,6 +435,10 @@ func ListIssues(client *github.Client, t translations.TranslationHelperFunc) (to opts.PerPage = int(perPage) } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } issues, resp, err := client.Issues.ListByRepo(ctx, owner, repo, opts) if err != nil { return nil, fmt.Errorf("failed to list issues: %w", err) @@ -443,7 +463,7 @@ func ListIssues(client *github.Client, t translations.TranslationHelperFunc) (to } // UpdateIssue creates a tool to update an existing issue in a GitHub repository. -func UpdateIssue(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("update_issue", mcp.WithDescription(t("TOOL_UPDATE_ISSUE_DESCRIPTION", "Update an existing issue in a GitHub repository")), mcp.WithString("owner", @@ -557,6 +577,10 @@ func UpdateIssue(client *github.Client, t translations.TranslationHelperFunc) (t issueRequest.Milestone = &milestoneNum } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } updatedIssue, resp, err := client.Issues.Edit(ctx, owner, repo, issueNumber, issueRequest) if err != nil { return nil, fmt.Errorf("failed to update issue: %w", err) @@ -581,7 +605,7 @@ func UpdateIssue(client *github.Client, t translations.TranslationHelperFunc) (t } // GetIssueComments creates a tool to get comments for a GitHub issue. -func GetIssueComments(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_issue_comments", mcp.WithDescription(t("TOOL_GET_ISSUE_COMMENTS_DESCRIPTION", "Get comments for a GitHub issue")), mcp.WithString("owner", @@ -632,6 +656,10 @@ func GetIssueComments(client *github.Client, t translations.TranslationHelperFun }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } comments, resp, err := client.Issues.ListComments(ctx, owner, repo, issueNumber, opts) if err != nil { return nil, fmt.Errorf("failed to get issue comments: %w", err) diff --git a/pkg/github/issues_test.go b/pkg/github/issues_test.go index e8b16e024..61ca0ae7a 100644 --- a/pkg/github/issues_test.go +++ b/pkg/github/issues_test.go @@ -18,7 +18,7 @@ import ( func Test_GetIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetIssue(mockClient, translations.NullTranslationHelper) + tool, _ := GetIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_issue", tool.Name) assert.NotEmpty(t, tool.Description) @@ -82,7 +82,7 @@ func Test_GetIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetIssue(client, translations.NullTranslationHelper) + _, handler := GetIssue(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -114,7 +114,7 @@ func Test_GetIssue(t *testing.T) { func Test_AddIssueComment(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := AddIssueComment(mockClient, translations.NullTranslationHelper) + tool, _ := AddIssueComment(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "add_issue_comment", tool.Name) assert.NotEmpty(t, tool.Description) @@ -185,7 +185,7 @@ func Test_AddIssueComment(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := AddIssueComment(client, translations.NullTranslationHelper) + _, handler := AddIssueComment(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := mcp.CallToolRequest{ @@ -237,7 +237,7 @@ func Test_AddIssueComment(t *testing.T) { func Test_SearchIssues(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := SearchIssues(mockClient, translations.NullTranslationHelper) + tool, _ := SearchIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_issues", tool.Name) assert.NotEmpty(t, tool.Description) @@ -352,7 +352,7 @@ func Test_SearchIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchIssues(client, translations.NullTranslationHelper) + _, handler := SearchIssues(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -393,7 +393,7 @@ func Test_SearchIssues(t *testing.T) { func Test_CreateIssue(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := CreateIssue(mockClient, translations.NullTranslationHelper) + tool, _ := CreateIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_issue", tool.Name) assert.NotEmpty(t, tool.Description) @@ -506,7 +506,7 @@ func Test_CreateIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateIssue(client, translations.NullTranslationHelper) + _, handler := CreateIssue(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -567,7 +567,7 @@ func Test_CreateIssue(t *testing.T) { func Test_ListIssues(t *testing.T) { // Verify tool definition mockClient := github.NewClient(nil) - tool, _ := ListIssues(mockClient, translations.NullTranslationHelper) + tool, _ := ListIssues(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_issues", tool.Name) assert.NotEmpty(t, tool.Description) @@ -698,7 +698,7 @@ func Test_ListIssues(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListIssues(client, translations.NullTranslationHelper) + _, handler := ListIssues(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -743,7 +743,7 @@ func Test_ListIssues(t *testing.T) { func Test_UpdateIssue(t *testing.T) { // Verify tool definition mockClient := github.NewClient(nil) - tool, _ := UpdateIssue(mockClient, translations.NullTranslationHelper) + tool, _ := UpdateIssue(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "update_issue", tool.Name) assert.NotEmpty(t, tool.Description) @@ -882,7 +882,7 @@ func Test_UpdateIssue(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdateIssue(client, translations.NullTranslationHelper) + _, handler := UpdateIssue(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1000,7 +1000,7 @@ func Test_ParseISOTimestamp(t *testing.T) { func Test_GetIssueComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetIssueComments(mockClient, translations.NullTranslationHelper) + tool, _ := GetIssueComments(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_issue_comments", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1100,7 +1100,7 @@ func Test_GetIssueComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetIssueComments(client, translations.NullTranslationHelper) + _, handler := GetIssueComments(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index c5f9d9fae..14aeb9187 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -14,7 +14,7 @@ import ( ) // GetPullRequest creates a tool to get details of a specific pull request. -func GetPullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_DESCRIPTION", "Get details of a specific pull request")), mcp.WithString("owner", @@ -44,6 +44,10 @@ func GetPullRequest(client *github.Client, t translations.TranslationHelperFunc) return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return nil, fmt.Errorf("failed to get pull request: %w", err) @@ -68,7 +72,7 @@ func GetPullRequest(client *github.Client, t translations.TranslationHelperFunc) } // UpdatePullRequest creates a tool to update an existing pull request. -func UpdatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_DESCRIPTION", "Update an existing pull request in a GitHub repository")), mcp.WithString("owner", @@ -157,6 +161,10 @@ func UpdatePullRequest(client *github.Client, t translations.TranslationHelperFu return mcp.NewToolResultError("No update parameters provided."), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } pr, resp, err := client.PullRequests.Edit(ctx, owner, repo, pullNumber, update) if err != nil { return nil, fmt.Errorf("failed to update pull request: %w", err) @@ -181,7 +189,7 @@ func UpdatePullRequest(client *github.Client, t translations.TranslationHelperFu } // ListPullRequests creates a tool to list and filter repository pull requests. -func ListPullRequests(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_pull_requests", mcp.WithDescription(t("TOOL_LIST_PULL_REQUESTS_DESCRIPTION", "List and filter repository pull requests")), mcp.WithString("owner", @@ -255,6 +263,10 @@ func ListPullRequests(client *github.Client, t translations.TranslationHelperFun }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } prs, resp, err := client.PullRequests.List(ctx, owner, repo, opts) if err != nil { return nil, fmt.Errorf("failed to list pull requests: %w", err) @@ -279,7 +291,7 @@ func ListPullRequests(client *github.Client, t translations.TranslationHelperFun } // MergePullRequest creates a tool to merge a pull request. -func MergePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("merge_pull_request", mcp.WithDescription(t("TOOL_MERGE_PULL_REQUEST_DESCRIPTION", "Merge a pull request")), mcp.WithString("owner", @@ -335,6 +347,10 @@ func MergePullRequest(client *github.Client, t translations.TranslationHelperFun MergeMethod: mergeMethod, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.PullRequests.Merge(ctx, owner, repo, pullNumber, commitMessage, options) if err != nil { return nil, fmt.Errorf("failed to merge pull request: %w", err) @@ -359,7 +375,7 @@ func MergePullRequest(client *github.Client, t translations.TranslationHelperFun } // GetPullRequestFiles creates a tool to get the list of files changed in a pull request. -func GetPullRequestFiles(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_files", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_FILES_DESCRIPTION", "Get the list of files changed in a pull request")), mcp.WithString("owner", @@ -389,6 +405,10 @@ func GetPullRequestFiles(client *github.Client, t translations.TranslationHelper return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } opts := &github.ListOptions{} files, resp, err := client.PullRequests.ListFiles(ctx, owner, repo, pullNumber, opts) if err != nil { @@ -414,7 +434,7 @@ func GetPullRequestFiles(client *github.Client, t translations.TranslationHelper } // GetPullRequestStatus creates a tool to get the combined status of all status checks for a pull request. -func GetPullRequestStatus(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_status", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_STATUS_DESCRIPTION", "Get the combined status of all status checks for a pull request")), mcp.WithString("owner", @@ -444,6 +464,10 @@ func GetPullRequestStatus(client *github.Client, t translations.TranslationHelpe return mcp.NewToolResultError(err.Error()), nil } // First get the PR to find the head SHA + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { return nil, fmt.Errorf("failed to get pull request: %w", err) @@ -483,7 +507,7 @@ func GetPullRequestStatus(client *github.Client, t translations.TranslationHelpe } // UpdatePullRequestBranch creates a tool to update a pull request branch with the latest changes from the base branch. -func UpdatePullRequestBranch(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("update_pull_request_branch", mcp.WithDescription(t("TOOL_UPDATE_PULL_REQUEST_BRANCH_DESCRIPTION", "Update a pull request branch with the latest changes from the base branch")), mcp.WithString("owner", @@ -524,6 +548,10 @@ func UpdatePullRequestBranch(client *github.Client, t translations.TranslationHe opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.PullRequests.UpdateBranch(ctx, owner, repo, pullNumber, opts) if err != nil { // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, @@ -553,7 +581,7 @@ func UpdatePullRequestBranch(client *github.Client, t translations.TranslationHe } // GetPullRequestComments creates a tool to get the review comments on a pull request. -func GetPullRequestComments(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_comments", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_COMMENTS_DESCRIPTION", "Get the review comments on a pull request")), mcp.WithString("owner", @@ -589,6 +617,10 @@ func GetPullRequestComments(client *github.Client, t translations.TranslationHel }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } comments, resp, err := client.PullRequests.ListComments(ctx, owner, repo, pullNumber, opts) if err != nil { return nil, fmt.Errorf("failed to get pull request comments: %w", err) @@ -613,7 +645,7 @@ func GetPullRequestComments(client *github.Client, t translations.TranslationHel } // GetPullRequestReviews creates a tool to get the reviews on a pull request. -func GetPullRequestReviews(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_pull_request_reviews", mcp.WithDescription(t("TOOL_GET_PULL_REQUEST_REVIEWS_DESCRIPTION", "Get the reviews on a pull request")), mcp.WithString("owner", @@ -643,6 +675,10 @@ func GetPullRequestReviews(client *github.Client, t translations.TranslationHelp return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } reviews, resp, err := client.PullRequests.ListReviews(ctx, owner, repo, pullNumber, nil) if err != nil { return nil, fmt.Errorf("failed to get pull request reviews: %w", err) @@ -667,7 +703,7 @@ func GetPullRequestReviews(client *github.Client, t translations.TranslationHelp } // CreatePullRequestReview creates a tool to submit a review on a pull request. -func CreatePullRequestReview(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func CreatePullRequestReview(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_pull_request_review", mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_REVIEW_DESCRIPTION", "Create a review on a pull request")), mcp.WithString("owner", @@ -835,6 +871,10 @@ func CreatePullRequestReview(client *github.Client, t translations.TranslationHe reviewRequest.Comments = comments } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } review, resp, err := client.PullRequests.CreateReview(ctx, owner, repo, pullNumber, reviewRequest) if err != nil { return nil, fmt.Errorf("failed to create pull request review: %w", err) @@ -859,7 +899,7 @@ func CreatePullRequestReview(client *github.Client, t translations.TranslationHe } // CreatePullRequest creates a tool to create a new pull request. -func CreatePullRequest(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_pull_request", mcp.WithDescription(t("TOOL_CREATE_PULL_REQUEST_DESCRIPTION", "Create a new pull request in a GitHub repository")), mcp.WithString("owner", @@ -942,6 +982,10 @@ func CreatePullRequest(client *github.Client, t translations.TranslationHelperFu newPR.Draft = github.Ptr(draft) newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } pr, resp, err := client.PullRequests.Create(ctx, owner, repo, newPR) if err != nil { return nil, fmt.Errorf("failed to create pull request: %w", err) diff --git a/pkg/github/pullrequests_test.go b/pkg/github/pullrequests_test.go index e9647029d..3c20dfc2c 100644 --- a/pkg/github/pullrequests_test.go +++ b/pkg/github/pullrequests_test.go @@ -17,7 +17,7 @@ import ( func Test_GetPullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetPullRequest(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request", tool.Name) assert.NotEmpty(t, tool.Description) @@ -94,7 +94,7 @@ func Test_GetPullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequest(client, translations.NullTranslationHelper) + _, handler := GetPullRequest(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -129,7 +129,7 @@ func Test_GetPullRequest(t *testing.T) { func Test_UpdatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequest(mockClient, translations.NullTranslationHelper) + tool, _ := UpdatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "update_pull_request", tool.Name) assert.NotEmpty(t, tool.Description) @@ -257,7 +257,7 @@ func Test_UpdatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequest(client, translations.NullTranslationHelper) + _, handler := UpdatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -311,7 +311,7 @@ func Test_UpdatePullRequest(t *testing.T) { func Test_ListPullRequests(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := ListPullRequests(mockClient, translations.NullTranslationHelper) + tool, _ := ListPullRequests(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_pull_requests", tool.Name) assert.NotEmpty(t, tool.Description) @@ -403,7 +403,7 @@ func Test_ListPullRequests(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListPullRequests(client, translations.NullTranslationHelper) + _, handler := ListPullRequests(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -441,7 +441,7 @@ func Test_ListPullRequests(t *testing.T) { func Test_MergePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := MergePullRequest(mockClient, translations.NullTranslationHelper) + tool, _ := MergePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "merge_pull_request", tool.Name) assert.NotEmpty(t, tool.Description) @@ -518,7 +518,7 @@ func Test_MergePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := MergePullRequest(client, translations.NullTranslationHelper) + _, handler := MergePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -552,7 +552,7 @@ func Test_MergePullRequest(t *testing.T) { func Test_GetPullRequestFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetPullRequestFiles(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestFiles(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_files", tool.Name) assert.NotEmpty(t, tool.Description) @@ -630,7 +630,7 @@ func Test_GetPullRequestFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestFiles(client, translations.NullTranslationHelper) + _, handler := GetPullRequestFiles(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -668,7 +668,7 @@ func Test_GetPullRequestFiles(t *testing.T) { func Test_GetPullRequestStatus(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetPullRequestStatus(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestStatus(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_status", tool.Name) assert.NotEmpty(t, tool.Description) @@ -790,7 +790,7 @@ func Test_GetPullRequestStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestStatus(client, translations.NullTranslationHelper) + _, handler := GetPullRequestStatus(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -829,7 +829,7 @@ func Test_GetPullRequestStatus(t *testing.T) { func Test_UpdatePullRequestBranch(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := UpdatePullRequestBranch(mockClient, translations.NullTranslationHelper) + tool, _ := UpdatePullRequestBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "update_pull_request_branch", tool.Name) assert.NotEmpty(t, tool.Description) @@ -917,7 +917,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := UpdatePullRequestBranch(client, translations.NullTranslationHelper) + _, handler := UpdatePullRequestBranch(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -945,7 +945,7 @@ func Test_UpdatePullRequestBranch(t *testing.T) { func Test_GetPullRequestComments(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetPullRequestComments(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestComments(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_comments", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1033,7 +1033,7 @@ func Test_GetPullRequestComments(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestComments(client, translations.NullTranslationHelper) + _, handler := GetPullRequestComments(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1072,7 +1072,7 @@ func Test_GetPullRequestComments(t *testing.T) { func Test_GetPullRequestReviews(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetPullRequestReviews(mockClient, translations.NullTranslationHelper) + tool, _ := GetPullRequestReviews(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_pull_request_reviews", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1156,7 +1156,7 @@ func Test_GetPullRequestReviews(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetPullRequestReviews(client, translations.NullTranslationHelper) + _, handler := GetPullRequestReviews(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1195,7 +1195,7 @@ func Test_GetPullRequestReviews(t *testing.T) { func Test_CreatePullRequestReview(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := CreatePullRequestReview(mockClient, translations.NullTranslationHelper) + tool, _ := CreatePullRequestReview(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_pull_request_review", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1523,7 +1523,7 @@ func Test_CreatePullRequestReview(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreatePullRequestReview(client, translations.NullTranslationHelper) + _, handler := CreatePullRequestReview(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -1566,7 +1566,7 @@ func Test_CreatePullRequestReview(t *testing.T) { func Test_CreatePullRequest(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := CreatePullRequest(mockClient, translations.NullTranslationHelper) + tool, _ := CreatePullRequest(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_pull_request", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1678,7 +1678,7 @@ func Test_CreatePullRequest(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreatePullRequest(client, translations.NullTranslationHelper) + _, handler := CreatePullRequest(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 2dafd4cee..f52c03414 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -14,7 +14,7 @@ import ( ) // ListCommits creates a tool to get commits of a branch in a repository. -func ListCommits(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("list_commits", mcp.WithDescription(t("TOOL_LIST_COMMITS_DESCRIPTION", "Get list of commits of a branch in a GitHub repository")), mcp.WithString("owner", @@ -56,6 +56,10 @@ func ListCommits(client *github.Client, t translations.TranslationHelperFunc) (t }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } commits, resp, err := client.Repositories.ListCommits(ctx, owner, repo, opts) if err != nil { return nil, fmt.Errorf("failed to list commits: %w", err) @@ -80,7 +84,7 @@ func ListCommits(client *github.Client, t translations.TranslationHelperFunc) (t } // CreateOrUpdateFile creates a tool to create or update a file in a GitHub repository. -func CreateOrUpdateFile(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_or_update_file", mcp.WithDescription(t("TOOL_CREATE_OR_UPDATE_FILE_DESCRIPTION", "Create or update a single file in a GitHub repository")), mcp.WithString("owner", @@ -157,6 +161,10 @@ func CreateOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // Create or update the file + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } fileContent, resp, err := client.Repositories.CreateFile(ctx, owner, repo, path, opts) if err != nil { return nil, fmt.Errorf("failed to create/update file: %w", err) @@ -181,7 +189,7 @@ func CreateOrUpdateFile(client *github.Client, t translations.TranslationHelperF } // CreateRepository creates a tool to create a new GitHub repository. -func CreateRepository(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_repository", mcp.WithDescription(t("TOOL_CREATE_REPOSITORY_DESCRIPTION", "Create a new GitHub repository in your account")), mcp.WithString("name", @@ -223,6 +231,10 @@ func CreateRepository(client *github.Client, t translations.TranslationHelperFun AutoInit: github.Ptr(autoInit), } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } createdRepo, resp, err := client.Repositories.Create(ctx, "", repo) if err != nil { return nil, fmt.Errorf("failed to create repository: %w", err) @@ -247,7 +259,7 @@ func CreateRepository(client *github.Client, t translations.TranslationHelperFun } // GetFileContents creates a tool to get the contents of a file or directory from a GitHub repository. -func GetFileContents(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetFileContents(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_file_contents", mcp.WithDescription(t("TOOL_GET_FILE_CONTENTS_DESCRIPTION", "Get the contents of a file or directory from a GitHub repository")), mcp.WithString("owner", @@ -284,6 +296,10 @@ func GetFileContents(client *github.Client, t translations.TranslationHelperFunc return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } opts := &github.RepositoryContentGetOptions{Ref: branch} fileContent, dirContent, resp, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) if err != nil { @@ -316,7 +332,7 @@ func GetFileContents(client *github.Client, t translations.TranslationHelperFunc } // ForkRepository creates a tool to fork a repository. -func ForkRepository(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("fork_repository", mcp.WithDescription(t("TOOL_FORK_REPOSITORY_DESCRIPTION", "Fork a GitHub repository to your account or specified organization")), mcp.WithString("owner", @@ -350,6 +366,10 @@ func ForkRepository(client *github.Client, t translations.TranslationHelperFunc) opts.Organization = org } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } forkedRepo, resp, err := client.Repositories.CreateFork(ctx, owner, repo, opts) if err != nil { // Check if it's an acceptedError. An acceptedError indicates that the update is in progress, @@ -379,7 +399,7 @@ func ForkRepository(client *github.Client, t translations.TranslationHelperFunc) } // CreateBranch creates a tool to create a new branch. -func CreateBranch(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("create_branch", mcp.WithDescription(t("TOOL_CREATE_BRANCH_DESCRIPTION", "Create a new branch in a GitHub repository")), mcp.WithString("owner", @@ -416,6 +436,11 @@ func CreateBranch(client *github.Client, t translations.TranslationHelperFunc) ( return mcp.NewToolResultError(err.Error()), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + // Get the source branch SHA var ref *github.Reference @@ -459,7 +484,7 @@ func CreateBranch(client *github.Client, t translations.TranslationHelperFunc) ( } // PushFiles creates a tool to push multiple files in a single commit to a GitHub repository. -func PushFiles(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("push_files", mcp.WithDescription(t("TOOL_PUSH_FILES_DESCRIPTION", "Push multiple files to a GitHub repository in a single commit")), mcp.WithString("owner", @@ -523,6 +548,11 @@ func PushFiles(client *github.Client, t translations.TranslationHelperFunc) (too return mcp.NewToolResultError("files parameter must be an array of objects with path and content"), nil } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + // Get the reference for the branch ref, resp, err := client.Git.GetRef(ctx, owner, repo, "refs/heads/"+branch) if err != nil { diff --git a/pkg/github/repositories_test.go b/pkg/github/repositories_test.go index 5c47183d0..2dc0cff96 100644 --- a/pkg/github/repositories_test.go +++ b/pkg/github/repositories_test.go @@ -18,7 +18,7 @@ import ( func Test_GetFileContents(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := GetFileContents(mockClient, translations.NullTranslationHelper) + tool, _ := GetFileContents(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_file_contents", tool.Name) assert.NotEmpty(t, tool.Description) @@ -132,7 +132,7 @@ func Test_GetFileContents(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetFileContents(client, translations.NullTranslationHelper) + _, handler := GetFileContents(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := mcp.CallToolRequest{ @@ -189,7 +189,7 @@ func Test_GetFileContents(t *testing.T) { func Test_ForkRepository(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := ForkRepository(mockClient, translations.NullTranslationHelper) + tool, _ := ForkRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "fork_repository", tool.Name) assert.NotEmpty(t, tool.Description) @@ -259,7 +259,7 @@ func Test_ForkRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ForkRepository(client, translations.NullTranslationHelper) + _, handler := ForkRepository(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -287,7 +287,7 @@ func Test_ForkRepository(t *testing.T) { func Test_CreateBranch(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := CreateBranch(mockClient, translations.NullTranslationHelper) + tool, _ := CreateBranch(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_branch", tool.Name) assert.NotEmpty(t, tool.Description) @@ -445,7 +445,7 @@ func Test_CreateBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateBranch(client, translations.NullTranslationHelper) + _, handler := CreateBranch(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -478,7 +478,7 @@ func Test_CreateBranch(t *testing.T) { func Test_ListCommits(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := ListCommits(mockClient, translations.NullTranslationHelper) + tool, _ := ListCommits(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "list_commits", tool.Name) assert.NotEmpty(t, tool.Description) @@ -614,7 +614,7 @@ func Test_ListCommits(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := ListCommits(client, translations.NullTranslationHelper) + _, handler := ListCommits(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -652,7 +652,7 @@ func Test_ListCommits(t *testing.T) { func Test_CreateOrUpdateFile(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := CreateOrUpdateFile(mockClient, translations.NullTranslationHelper) + tool, _ := CreateOrUpdateFile(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_or_update_file", tool.Name) assert.NotEmpty(t, tool.Description) @@ -775,7 +775,7 @@ func Test_CreateOrUpdateFile(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateOrUpdateFile(client, translations.NullTranslationHelper) + _, handler := CreateOrUpdateFile(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -815,7 +815,7 @@ func Test_CreateOrUpdateFile(t *testing.T) { func Test_CreateRepository(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := CreateRepository(mockClient, translations.NullTranslationHelper) + tool, _ := CreateRepository(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "create_repository", tool.Name) assert.NotEmpty(t, tool.Description) @@ -923,7 +923,7 @@ func Test_CreateRepository(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := CreateRepository(client, translations.NullTranslationHelper) + _, handler := CreateRepository(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -961,7 +961,7 @@ func Test_CreateRepository(t *testing.T) { func Test_PushFiles(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := PushFiles(mockClient, translations.NullTranslationHelper) + tool, _ := PushFiles(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "push_files", tool.Name) assert.NotEmpty(t, tool.Description) @@ -1256,7 +1256,7 @@ func Test_PushFiles(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := PushFiles(client, translations.NullTranslationHelper) + _, handler := PushFiles(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 47cb8bf64..949157f55 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -18,52 +18,52 @@ import ( ) // GetRepositoryResourceContent defines the resource template and handler for getting repository content. -func GetRepositoryResourceContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +func GetRepositoryResourceContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_DESCRIPTION", "Repository Content"), ), - RepositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } // GetRepositoryResourceBranchContent defines the resource template and handler for getting repository content for a branch. -func GetRepositoryResourceBranchContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +func GetRepositoryResourceBranchContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/refs/heads/{branch}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_BRANCH_DESCRIPTION", "Repository Content for specific branch"), ), - RepositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } // GetRepositoryResourceCommitContent defines the resource template and handler for getting repository content for a commit. -func GetRepositoryResourceCommitContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +func GetRepositoryResourceCommitContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/sha/{sha}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_COMMIT_DESCRIPTION", "Repository Content for specific commit"), ), - RepositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } // GetRepositoryResourceTagContent defines the resource template and handler for getting repository content for a tag. -func GetRepositoryResourceTagContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +func GetRepositoryResourceTagContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/refs/tags/{tag}/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_TAG_DESCRIPTION", "Repository Content for specific tag"), ), - RepositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } // GetRepositoryResourcePrContent defines the resource template and handler for getting repository content for a pull request. -func GetRepositoryResourcePrContent(client *github.Client, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { +func GetRepositoryResourcePrContent(getClient GetClientFn, t translations.TranslationHelperFunc) (mcp.ResourceTemplate, server.ResourceTemplateHandlerFunc) { return mcp.NewResourceTemplate( "repo://{owner}/{repo}/refs/pull/{prNumber}/head/contents{/path*}", // Resource template t("RESOURCE_REPOSITORY_CONTENT_PR_DESCRIPTION", "Repository Content for specific pull request"), ), - RepositoryResourceContentsHandler(client) + RepositoryResourceContentsHandler(getClient) } // RepositoryResourceContentsHandler returns a handler function for repository content requests. -func RepositoryResourceContentsHandler(client *github.Client) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { +func RepositoryResourceContentsHandler(getClient GetClientFn) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // the matcher will give []string with one element // https://github.com/mark3labs/mcp-go/pull/54 @@ -107,6 +107,10 @@ func RepositoryResourceContentsHandler(client *github.Client) func(ctx context.C opts.Ref = "refs/pull/" + prNumber[0] + "/head" } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } fileContent, directoryContent, _, err := client.Repositories.GetContents(ctx, owner, repo, path, opts) if err != nil { return nil, err diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index c274d1b53..ffd14be32 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -234,7 +234,7 @@ func Test_repositoryResourceContentsHandler(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { client := github.NewClient(tc.mockedClient) - handler := RepositoryResourceContentsHandler(client) + handler := RepositoryResourceContentsHandler((stubGetClientFn(client))) request := mcp.ReadResourceRequest{ Params: struct { diff --git a/pkg/github/search.go b/pkg/github/search.go index cd2ab4346..75810e245 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -13,7 +13,7 @@ import ( ) // SearchRepositories creates a tool to search for GitHub repositories. -func SearchRepositories(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_repositories", mcp.WithDescription(t("TOOL_SEARCH_REPOSITORIES_DESCRIPTION", "Search for GitHub repositories")), mcp.WithString("query", @@ -39,6 +39,10 @@ func SearchRepositories(client *github.Client, t translations.TranslationHelperF }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } result, resp, err := client.Search.Repositories(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search repositories: %w", err) @@ -63,7 +67,7 @@ func SearchRepositories(client *github.Client, t translations.TranslationHelperF } // SearchCode creates a tool to search for code across GitHub repositories. -func SearchCode(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_code", mcp.WithDescription(t("TOOL_SEARCH_CODE_DESCRIPTION", "Search for code across GitHub repositories")), mcp.WithString("q", @@ -106,6 +110,11 @@ func SearchCode(client *github.Client, t translations.TranslationHelperFunc) (to }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.Search.Code(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search code: %w", err) @@ -130,7 +139,7 @@ func SearchCode(client *github.Client, t translations.TranslationHelperFunc) (to } // SearchUsers creates a tool to search for GitHub users. -func SearchUsers(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("search_users", mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users")), mcp.WithString("q", @@ -174,6 +183,11 @@ func SearchUsers(client *github.Client, t translations.TranslationHelperFunc) (t }, } + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + result, resp, err := client.Search.Users(ctx, query, opts) if err != nil { return nil, fmt.Errorf("failed to search users: %w", err) diff --git a/pkg/github/search_test.go b/pkg/github/search_test.go index b000a0bfb..b61518e47 100644 --- a/pkg/github/search_test.go +++ b/pkg/github/search_test.go @@ -16,7 +16,7 @@ import ( func Test_SearchRepositories(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := SearchRepositories(mockClient, translations.NullTranslationHelper) + tool, _ := SearchRepositories(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_repositories", tool.Name) assert.NotEmpty(t, tool.Description) @@ -122,7 +122,7 @@ func Test_SearchRepositories(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchRepositories(client, translations.NullTranslationHelper) + _, handler := SearchRepositories(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -163,7 +163,7 @@ func Test_SearchRepositories(t *testing.T) { func Test_SearchCode(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := SearchCode(mockClient, translations.NullTranslationHelper) + tool, _ := SearchCode(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_code", tool.Name) assert.NotEmpty(t, tool.Description) @@ -273,7 +273,7 @@ func Test_SearchCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchCode(client, translations.NullTranslationHelper) + _, handler := SearchCode(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) @@ -314,7 +314,7 @@ func Test_SearchCode(t *testing.T) { func Test_SearchUsers(t *testing.T) { // Verify tool definition once mockClient := github.NewClient(nil) - tool, _ := SearchUsers(mockClient, translations.NullTranslationHelper) + tool, _ := SearchUsers(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "search_users", tool.Name) assert.NotEmpty(t, tool.Description) @@ -428,7 +428,7 @@ func Test_SearchUsers(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := SearchUsers(client, translations.NullTranslationHelper) + _, handler := SearchUsers(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs) diff --git a/pkg/github/server.go b/pkg/github/server.go index 80457a54f..9dee1596c 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -14,8 +14,10 @@ import ( "github.com/mark3labs/mcp-go/server" ) +type GetClientFn func(context.Context) (*github.Client, error) + // NewServer creates a new GitHub MCP server with the specified GH client and logger. -func NewServer(client *github.Client, version string, readOnly bool, t translations.TranslationHelperFunc) *server.MCPServer { +func NewServer(getClient GetClientFn, version string, readOnly bool, t translations.TranslationHelperFunc) *server.MCPServer { // Create a new MCP server s := server.NewMCPServer( "github-mcp-server", @@ -24,65 +26,65 @@ func NewServer(client *github.Client, version string, readOnly bool, t translati server.WithLogging()) // Add GitHub Resources - s.AddResourceTemplate(GetRepositoryResourceContent(client, t)) - s.AddResourceTemplate(GetRepositoryResourceBranchContent(client, t)) - s.AddResourceTemplate(GetRepositoryResourceCommitContent(client, t)) - s.AddResourceTemplate(GetRepositoryResourceTagContent(client, t)) - s.AddResourceTemplate(GetRepositoryResourcePrContent(client, t)) + s.AddResourceTemplate(GetRepositoryResourceContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourceBranchContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourceCommitContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourceTagContent(getClient, t)) + s.AddResourceTemplate(GetRepositoryResourcePrContent(getClient, t)) // Add GitHub tools - Issues - s.AddTool(GetIssue(client, t)) - s.AddTool(SearchIssues(client, t)) - s.AddTool(ListIssues(client, t)) - s.AddTool(GetIssueComments(client, t)) + s.AddTool(GetIssue(getClient, t)) + s.AddTool(SearchIssues(getClient, t)) + s.AddTool(ListIssues(getClient, t)) + s.AddTool(GetIssueComments(getClient, t)) if !readOnly { - s.AddTool(CreateIssue(client, t)) - s.AddTool(AddIssueComment(client, t)) - s.AddTool(UpdateIssue(client, t)) + s.AddTool(CreateIssue(getClient, t)) + s.AddTool(AddIssueComment(getClient, t)) + s.AddTool(UpdateIssue(getClient, t)) } // Add GitHub tools - Pull Requests - s.AddTool(GetPullRequest(client, t)) - s.AddTool(ListPullRequests(client, t)) - s.AddTool(GetPullRequestFiles(client, t)) - s.AddTool(GetPullRequestStatus(client, t)) - s.AddTool(GetPullRequestComments(client, t)) - s.AddTool(GetPullRequestReviews(client, t)) + s.AddTool(GetPullRequest(getClient, t)) + s.AddTool(ListPullRequests(getClient, t)) + s.AddTool(GetPullRequestFiles(getClient, t)) + s.AddTool(GetPullRequestStatus(getClient, t)) + s.AddTool(GetPullRequestComments(getClient, t)) + s.AddTool(GetPullRequestReviews(getClient, t)) if !readOnly { - s.AddTool(MergePullRequest(client, t)) - s.AddTool(UpdatePullRequestBranch(client, t)) - s.AddTool(CreatePullRequestReview(client, t)) - s.AddTool(CreatePullRequest(client, t)) - s.AddTool(UpdatePullRequest(client, t)) + s.AddTool(MergePullRequest(getClient, t)) + s.AddTool(UpdatePullRequestBranch(getClient, t)) + s.AddTool(CreatePullRequestReview(getClient, t)) + s.AddTool(CreatePullRequest(getClient, t)) + s.AddTool(UpdatePullRequest(getClient, t)) } // Add GitHub tools - Repositories - s.AddTool(SearchRepositories(client, t)) - s.AddTool(GetFileContents(client, t)) - s.AddTool(ListCommits(client, t)) + s.AddTool(SearchRepositories(getClient, t)) + s.AddTool(GetFileContents(getClient, t)) + s.AddTool(ListCommits(getClient, t)) if !readOnly { - s.AddTool(CreateOrUpdateFile(client, t)) - s.AddTool(CreateRepository(client, t)) - s.AddTool(ForkRepository(client, t)) - s.AddTool(CreateBranch(client, t)) - s.AddTool(PushFiles(client, t)) + s.AddTool(CreateOrUpdateFile(getClient, t)) + s.AddTool(CreateRepository(getClient, t)) + s.AddTool(ForkRepository(getClient, t)) + s.AddTool(CreateBranch(getClient, t)) + s.AddTool(PushFiles(getClient, t)) } // Add GitHub tools - Search - s.AddTool(SearchCode(client, t)) - s.AddTool(SearchUsers(client, t)) + s.AddTool(SearchCode(getClient, t)) + s.AddTool(SearchUsers(getClient, t)) // Add GitHub tools - Users - s.AddTool(GetMe(client, t)) + s.AddTool(GetMe(getClient, t)) // Add GitHub tools - Code Scanning - s.AddTool(GetCodeScanningAlert(client, t)) - s.AddTool(ListCodeScanningAlerts(client, t)) + s.AddTool(GetCodeScanningAlert(getClient, t)) + s.AddTool(ListCodeScanningAlerts(getClient, t)) return s } // GetMe creates a tool to get details of the authenticated user. -func GetMe(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { +func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { return mcp.NewTool("get_me", mcp.WithDescription(t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request include \"me\", \"my\"...")), mcp.WithString("reason", @@ -90,6 +92,10 @@ func GetMe(client *github.Client, t translations.TranslationHelperFunc) (tool mc ), ), func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } user, resp, err := client.Users.Get(ctx, "") if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 979046fc8..3ee9851af 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -15,10 +15,16 @@ import ( "github.com/stretchr/testify/require" ) +func stubGetClientFn(client *github.Client) GetClientFn { + return func(_ context.Context) (*github.Client, error) { + return client, nil + } +} + func Test_GetMe(t *testing.T) { // Verify tool definition mockClient := github.NewClient(nil) - tool, _ := GetMe(mockClient, translations.NullTranslationHelper) + tool, _ := GetMe(stubGetClientFn(mockClient), translations.NullTranslationHelper) assert.Equal(t, "get_me", tool.Name) assert.NotEmpty(t, tool.Description) @@ -96,7 +102,7 @@ func Test_GetMe(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Setup client with mock client := github.NewClient(tc.mockedClient) - _, handler := GetMe(client, translations.NullTranslationHelper) + _, handler := GetMe(stubGetClientFn(client), translations.NullTranslationHelper) // Create call request request := createMCPRequest(tc.requestArgs)