From 6d6377795612490655b4040e3859e8d49a1f444c Mon Sep 17 00:00:00 2001
From: Javier Uruen Val <juruen@github.com>
Date: Mon, 17 Mar 2025 13:29:55 +0100
Subject: [PATCH] add iologging for debugging purposes

---
 cmd/server/main.go | 19 +++++++++++---
 go.mod             |  3 +++
 pkg/log/io.go      | 45 ++++++++++++++++++++++++++++++++
 pkg/log/io_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 129 insertions(+), 3 deletions(-)
 create mode 100644 pkg/log/io.go
 create mode 100644 pkg/log/io_test.go

diff --git a/cmd/server/main.go b/cmd/server/main.go
index a4266a00d..bfe5df2b4 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -3,12 +3,14 @@ package main
 import (
 	"context"
 	"fmt"
+	"io"
 	stdlog "log"
 	"os"
 	"os/signal"
 	"syscall"
 
 	"github.com/github/github-mcp-server/pkg/github"
+	iolog "github.com/github/github-mcp-server/pkg/log"
 	gogithub "github.com/google/go-github/v69/github"
 	"github.com/mark3labs/mcp-go/server"
 	log "github.com/sirupsen/logrus"
@@ -33,7 +35,8 @@ var (
 			if err != nil {
 				stdlog.Fatal("Failed to initialize logger:", err)
 			}
-			if err := runStdioServer(logger); err != nil {
+			logCommands := viper.GetBool("enable-command-logging")
+			if err := runStdioServer(logger, logCommands); err != nil {
 				stdlog.Fatal("failed to run stdio server:", err)
 			}
 		},
@@ -45,9 +48,11 @@ func init() {
 
 	// Add global flags that will be shared by all commands
 	rootCmd.PersistentFlags().String("log-file", "", "Path to log file")
+	rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file")
 
 	// Bind flag to viper
 	viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file"))
+	viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging"))
 
 	// Add subcommands
 	rootCmd.AddCommand(stdioCmd)
@@ -70,12 +75,13 @@ func initLogger(outPath string) (*log.Logger, error) {
 	}
 
 	logger := log.New()
+	logger.SetLevel(log.DebugLevel)
 	logger.SetOutput(file)
 
 	return logger, nil
 }
 
-func runStdioServer(logger *log.Logger) error {
+func runStdioServer(logger *log.Logger, logCommands bool) error {
 	// Create app context
 	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
 	defer stop()
@@ -97,7 +103,14 @@ func runStdioServer(logger *log.Logger) error {
 	// Start listening for messages
 	errC := make(chan error, 1)
 	go func() {
-		errC <- stdioServer.Listen(ctx, os.Stdin, os.Stdout)
+		in, out := io.Reader(os.Stdin), io.Writer(os.Stdout)
+
+		if logCommands {
+			loggedIO := iolog.NewIOLogger(in, out, logger)
+			in, out = loggedIO, loggedIO
+		}
+
+		errC <- stdioServer.Listen(ctx, in, out)
 	}()
 
 	// Output github-mcp-server string
diff --git a/go.mod b/go.mod
index e53b8b6b1..6fbe54a43 100644
--- a/go.mod
+++ b/go.mod
@@ -9,10 +9,12 @@ require (
 	github.com/sirupsen/logrus v1.9.3
 	github.com/spf13/cobra v1.9.1
 	github.com/spf13/viper v1.19.0
+	github.com/stretchr/testify v1.9.0
 	golang.org/x/exp v0.0.0-20230905200255-921286631fa9
 )
 
 require (
+	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
 	github.com/fsnotify/fsnotify v1.7.0 // indirect
 	github.com/google/go-querystring v1.1.0 // indirect
 	github.com/google/uuid v1.6.0 // indirect
@@ -21,6 +23,7 @@ require (
 	github.com/magiconair/properties v1.8.7 // indirect
 	github.com/mitchellh/mapstructure v1.5.0 // indirect
 	github.com/pelletier/go-toml/v2 v2.2.2 // indirect
+	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/sagikazarmark/locafero v0.4.0 // indirect
 	github.com/sagikazarmark/slog-shim v0.1.0 // indirect
 	github.com/sourcegraph/conc v0.3.0 // indirect
diff --git a/pkg/log/io.go b/pkg/log/io.go
new file mode 100644
index 000000000..de2210278
--- /dev/null
+++ b/pkg/log/io.go
@@ -0,0 +1,45 @@
+package log
+
+import (
+	"io"
+
+	log "github.com/sirupsen/logrus"
+)
+
+// IOLogger is a wrapper around io.Reader and io.Writer that can be used
+// to log the data being read and written from the underlying streams
+type IOLogger struct {
+	reader io.Reader
+	writer io.Writer
+	logger *log.Logger
+}
+
+// NewIOLogger creates a new IOLogger instance
+func NewIOLogger(r io.Reader, w io.Writer, logger *log.Logger) *IOLogger {
+	return &IOLogger{
+		reader: r,
+		writer: w,
+		logger: logger,
+	}
+}
+
+// Read reads data from the underlying io.Reader and logs it.
+func (l *IOLogger) Read(p []byte) (n int, err error) {
+	if l.reader == nil {
+		return 0, io.EOF
+	}
+	n, err = l.reader.Read(p)
+	if n > 0 {
+		l.logger.Infof("[stdin]: received %d bytes: %s", n, string(p[:n]))
+	}
+	return n, err
+}
+
+// Write writes data to the underlying io.Writer and logs it.
+func (l *IOLogger) Write(p []byte) (n int, err error) {
+	if l.writer == nil {
+		return 0, io.ErrClosedPipe
+	}
+	l.logger.Infof("[stdout]: sending %d bytes: %s", len(p), string(p))
+	return l.writer.Write(p)
+}
diff --git a/pkg/log/io_test.go b/pkg/log/io_test.go
new file mode 100644
index 000000000..0d0cd8959
--- /dev/null
+++ b/pkg/log/io_test.go
@@ -0,0 +1,65 @@
+package log
+
+import (
+	"bytes"
+	"strings"
+	"testing"
+
+	log "github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestLoggedReadWriter(t *testing.T) {
+	t.Run("Read method logs and passes data", func(t *testing.T) {
+		// Setup
+		inputData := "test input data"
+		reader := strings.NewReader(inputData)
+
+		// Create logger with buffer to capture output
+		var logBuffer bytes.Buffer
+		logger := log.New()
+		logger.SetOutput(&logBuffer)
+		logger.SetFormatter(&log.TextFormatter{
+			DisableTimestamp: true,
+		})
+
+		lrw := NewIOLogger(reader, nil, logger)
+
+		// Test Read
+		buf := make([]byte, 100)
+		n, err := lrw.Read(buf)
+
+		// Assertions
+		assert.NoError(t, err)
+		assert.Equal(t, len(inputData), n)
+		assert.Equal(t, inputData, string(buf[:n]))
+		assert.Contains(t, logBuffer.String(), "[stdin]")
+		assert.Contains(t, logBuffer.String(), inputData)
+	})
+
+	t.Run("Write method logs and passes data", func(t *testing.T) {
+		// Setup
+		outputData := "test output data"
+		var writeBuffer bytes.Buffer
+
+		// Create logger with buffer to capture output
+		var logBuffer bytes.Buffer
+		logger := log.New()
+		logger.SetOutput(&logBuffer)
+		logger.SetFormatter(&log.TextFormatter{
+			DisableTimestamp: true,
+		})
+
+		lrw := NewIOLogger(nil, &writeBuffer, logger)
+
+		// Test Write
+		n, err := lrw.Write([]byte(outputData))
+
+		// Assertions
+		assert.NoError(t, err)
+		assert.Equal(t, len(outputData), n)
+		assert.Equal(t, outputData, writeBuffer.String())
+		assert.Contains(t, logBuffer.String(), "[stdout]")
+		assert.Contains(t, logBuffer.String(), outputData)
+	})
+}