diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index ec682a735c248..6e3760c643cb3 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -117,6 +117,10 @@ type Config struct { // Note that this is different from the devcontainers feature, which uses // subagents. ExperimentalContainers bool + // X11Net allows overriding the networking implementation used for X11 + // forwarding listeners. When nil, a default implementation backed by the + // standard library networking package is used. + X11Net X11Network } type Server struct { @@ -196,6 +200,12 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom displayOffset: *config.X11DisplayOffset, sessions: make(map[*x11Session]struct{}), connections: make(map[net.Conn]struct{}), + network: func() X11Network { + if config.X11Net != nil { + return config.X11Net + } + return osNet{} + }(), }, } diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index 8c23d32bfa5d1..05d9f866c16f6 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -37,12 +37,30 @@ const ( X11MaxPort = X11StartPort + X11MaxDisplays ) +// X11Network abstracts the creation of network listeners for X11 forwarding. +// It is intended mainly for testing; production code uses the default +// implementation backed by the operating system networking stack. +type X11Network interface { + Listen(network, address string) (net.Listener, error) +} + +// osNet is the default X11Network implementation that uses the standard +// library network stack. +type osNet struct{} + +func (osNet) Listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} + type x11Forwarder struct { logger slog.Logger x11HandlerErrors *prometheus.CounterVec fs afero.Fs displayOffset int + // network creates X11 listener sockets. Defaults to osNet{}. + network X11Network + mu sync.Mutex sessions map[*x11Session]struct{} connections map[net.Conn]struct{} @@ -147,26 +165,27 @@ func (x *x11Forwarder) listenForConnections( x.closeAndRemoveSession(session) } - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - x.logger.Warn(ctx, fmt.Sprintf("failed to cast connection to TCPConn. got: %T", conn)) - _ = conn.Close() - continue + var originAddr string + var originPort uint32 + + if tcpConn, ok := conn.(*net.TCPConn); ok { + if tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr); ok { + originAddr = tcpAddr.IP.String() + // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) + originPort = uint32(tcpAddr.Port) + } } - tcpAddr, ok := tcpConn.LocalAddr().(*net.TCPAddr) - if !ok { - x.logger.Warn(ctx, fmt.Sprintf("failed to cast local address to TCPAddr. got: %T", tcpConn.LocalAddr())) - _ = conn.Close() - continue + // Fallback values for in-memory or non-TCP connections. + if originAddr == "" { + originAddr = "127.0.0.1" } channel, reqs, err := serverConn.OpenChannel("x11", gossh.Marshal(struct { OriginatorAddress string OriginatorPort uint32 }{ - OriginatorAddress: tcpAddr.IP.String(), - // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535) - OriginatorPort: uint32(tcpAddr.Port), + OriginatorAddress: originAddr, + OriginatorPort: originPort, })) if err != nil { x.logger.Warn(ctx, "failed to open X11 channel", slog.Error(err)) @@ -287,13 +306,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() { // createX11Listener creates a listener for X11 forwarding, it will use // the next available port starting from X11StartPort and displayOffset. func (x *x11Forwarder) createX11Listener(ctx context.Context) (ln net.Listener, display int, err error) { - var lc net.ListenConfig // Look for an open port to listen on. for port := X11StartPort + x.displayOffset; port <= X11MaxPort; port++ { if ctx.Err() != nil { return nil, -1, ctx.Err() } - ln, err = lc.Listen(ctx, "tcp", fmt.Sprintf("localhost:%d", port)) + + ln, err = x.network.Listen("tcp", fmt.Sprintf("localhost:%d", port)) if err == nil { display = port - X11StartPort return ln, display, nil diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go index 39440da7127b8..a680c088de703 100644 --- a/agent/agentssh/x11_test.go +++ b/agent/agentssh/x11_test.go @@ -3,7 +3,6 @@ package agentssh_test import ( "bufio" "bytes" - "context" "encoding/hex" "fmt" "net" @@ -32,10 +31,19 @@ func TestServer_X11(t *testing.T) { t.Skip("X11 forwarding is only supported on Linux") } - ctx := context.Background() + ctx := testutil.Context(t, testutil.WaitShort) logger := testutil.Logger(t) fs := afero.NewMemMapFs() - s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{}) + + // Use in-process networking for X11 forwarding. + inproc := testutil.NewInProcNet() + + // Create server config with custom X11 listener. + cfg := &agentssh.Config{ + X11Net: inproc, + } + + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg) require.NoError(t, err) defer s.Close() err = s.UpdateHostSigner(42) @@ -93,17 +101,15 @@ func TestServer_X11(t *testing.T) { x11Chans := c.HandleChannelOpen("x11") payload := "hello world" - require.Eventually(t, func() bool { - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber)) - if err == nil { - _, err = conn.Write([]byte(payload)) - assert.NoError(t, err) - _ = conn.Close() - } - return err == nil - }, testutil.WaitShort, testutil.IntervalFast) + go func() { + conn, err := inproc.Dial(ctx, testutil.NewAddr("tcp", fmt.Sprintf("localhost:%d", agentssh.X11StartPort+displayNumber))) + assert.NoError(t, err) + _, err = conn.Write([]byte(payload)) + assert.NoError(t, err) + _ = conn.Close() + }() - x11 := <-x11Chans + x11 := testutil.RequireReceive(ctx, t, x11Chans) ch, reqs, err := x11.Accept() require.NoError(t, err) go gossh.DiscardRequests(reqs)