8000 chore: unit test for X11 eviction · coder/coder@a6ab9a1 · GitHub
[go: up one dir, main page]

Skip to content

Commit a6ab9a1

Browse files
committed
chore: unit test for X11 eviction
1 parent a5bfb20 commit a6ab9a1

File tree

2 files changed

+165
-3
lines changed

2 files changed

+165
-3
lines changed

agent/agentssh/x11.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,17 @@ func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
8383

8484
// x11Handler is called when a session has requested X11 forwarding.
8585
// It listens for X11 connections and forwards them to the client.
86-
func (x *x11Forwarder) x11Handler(ctx ssh.Context, sshSession ssh.Session) (displayNumber int, handled bool) {
86+
func (x *x11Forwarder) x11Handler(sshCtx ssh.Context, sshSession ssh.Session) (displayNumber int, handled bool) {
8787
x11, hasX11 := sshSession.X11()
8888
if !hasX11 {
8989
return -1, false
9090
}
91-
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
91+
serverConn, valid := sshCtx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
9292
if !valid {
93-
x.logger.Warn(ctx, "failed to get server connection")
93+
x.logger.Warn(sshCtx, "failed to get server connection")
9494
return -1, false
9595
}
96+
ctx := slog.With(sshCtx, slog.F("session_id", fmt.Sprintf("%x", serverConn.SessionID())))
9697

9798
hostname, err := os.Hostname()
9899
if err != nil {
@@ -127,6 +128,7 @@ func (x *x11Forwarder) x11Handler(ctx ssh.Context, sshSession ssh.Session) (disp
127128
}()
128129

129130
go x.listenForConnections(ctx, x11session, serverConn, x11)
131+
x.logger.Debug(ctx, "X11 forwarding started", slog.F("display", x11session.display))
130132

131133
return x11session.display, true
132134
}

agent/agentssh/x11_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"bytes"
66
"encoding/hex"
77
"fmt"
8+
"io"
89
"net"
910
"os"
1011
"path/filepath"
@@ -127,3 +128,162 @@ func TestServer_X11(t *testing.T) {
127128
_, err = fs.Stat(filepath.Join(home, ".Xauthority"))
128129
require.NoError(t, err)
129130
}
131+
132+
func TestServer_X11_EvictionLRU(t *testing.T) {
133+
t.Parallel()
134+
if runtime.GOOS != "linux" {
135+
t.Skip("X11 forwarding is only supported on Linux")
136+
}
137+
138+
ctx := testutil.Context(t, testutil.WaitLong)
139+
logger := testutil.Logger(t)
140+
fs := afero.NewMemMapFs()
141+
142+
// Use in-process networking for X11 forwarding.
143+
inproc := testutil.NewInProcNet()
144+
145+
cfg := &agentssh.Config{
146+
X11Net: inproc,
147+
}
148+
149+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
150+
require.NoError(t, err)
151+
defer s.Close()
152+
err = s.UpdateHostSigner(42)
153+
require.NoError(t, err)
154+
155+
ln, err := net.Listen("tcp", "127.0.0.1:0")
156+
require.NoError(t, err)
157+
158+
done := testutil.Go(t, func() {
159+
err := s.Serve(ln)
160+
assert.Error(t, err)
161+
})
162+
163+
c := sshClient(t, ln.Addr().String())
164+
165+
// Calculate how many simultaneous X11 sessions we can create given the
166+
// configured port range.
167+
startPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset
168+
maxSessions := agentssh.X11MaxPort - startPort + 1
169+
require.Greater(t, maxSessions, 0, "expected a positive maxSessions value")
170+
171+
// shellSession holds references to the session and its standard streams so
172+
// that the test can keep them open (and optionally interact with them) for
173+
// the lifetime of the test. If we don't start the Shell with pipes in place,
174+
// the session will be torn down asynchronously during the test.
175+
type shellSession struct {
176+
sess *gossh.Session
177+
stdin io.WriteCloser
178+
stdout io.Reader
179+
stderr io.Reader
180+
// scanner is used to read the output of the session, line by line.
181+
scanner *bufio.Scanner
182+
}
183+
184+
sessions := make([]shellSession, 0, maxSessions)
185+
for i := 0; i < maxSessions; i++ {
186+
sess, err := c.NewSession()
187+
require.NoError(t, err)
188+
189+
_, err = sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
190+
AuthProtocol: "MIT-MAGIC-COOKIE-1",
191+
AuthCookie: hex.EncodeToString([]byte(fmt.Sprintf("cookie%d", i))),
192+
ScreenNumber: uint32(0),
193+
}))
194+
require.NoError(t, err)
195+
196+
stdin, err := sess.StdinPipe()
197+
require.NoError(t, err)
198+
stdout, err := sess.StdoutPipe()
199+
require.NoError(t, err)
200+
stderr, err := sess.StderrPipe()
201+
require.NoError(t, err)
202+
require.NoError(t, sess.Shell())
203+
204+
// The SSH server lazily starts the session. We need to write a command
205+
// and read back to ensure the X11 forwarding is started.
206+
scanner := bufio.NewScanner(stdout)
207+
msg := fmt.Sprintf("ready-%d", i)
208+
_, err = stdin.Write([]byte("echo " + msg + "\n"))
209+
require.NoError(t, err)
210+
// Read until we get the message (first token may be empty due to shell prompt)
211+
for scanner.Scan() {
212+
line := strings.TrimSpace(scanner.Text())
213+
if strings.Contains(line, msg) {
214+
break
215+
}
216+
}
217+
require.NoError(t, scanner.Err())
218+
219+
sessions = append(sessions, shellSession{
220+
sess: sess,
221+
stdin: stdin,
222+
stdout: stdout,
223+
stderr: stderr,
224+
scanner: scanner,
225+
})
226+
}
227+
228+
// Create one more session which should evict the first (LRU) session and
229+
// therefore reuse the very first display/port.
230+
extraSess, err := c.NewSession()
231+
require.NoError(t, err)
232+
233+
_, err = extraSess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
234+
AuthProtocol: "MIT-MAGIC-COOKIE-1",
235+
AuthCookie: hex.EncodeToString([]byte("extra")),
236+
ScreenNumber: uint32(0),
237+
}))
238+
require.NoError(t, err)
239+
240+
// Ask the remote side for the DISPLAY value so we can extract the display
241+
// number that was assigned to this session.
242+
out, err := extraSess.Output("echo DISPLAY=$DISPLAY")
243+
require.NoError(t, err)
244+
245+
// Example output line: "DISPLAY=localhost:10.0".
246+
var newDisplayNumber int
247+
{
248+
sc := bufio.NewScanner(bytes.NewReader(out))
249+
for sc.Scan() {
250+
line := strings.TrimSpace(sc.Text())
251+
if strings.HasPrefix(line, "DISPLAY=") {
252+
parts := strings.SplitN(line, ":", 2)
253+
require.Len(t, parts, 2)
254+
displayPart := parts[1]
255+
if strings.Contains(displayPart, ".") {
256+
displayPart = strings.SplitN(displayPart, ".", 2)[0]
257+
}
258+
var convErr error
259+
newDisplayNumber, convErr = strconv.Atoi(displayPart)
260+
require.NoError(t, convErr)
261+
break
262+
}
263+
}
264+
require.NoError(t, sc.Err())
265+
}
266+
267+
// The display number should have wrapped around to the starting value.
268+
assert.Equal(t, agentssh.X11DefaultDisplayOffset, newDisplayNumber, "expected display number to be reused after eviction")
269+
270+
// validate that the first session was torn down.
271+
_, err = sessions[0].stdin.Write([]byte("echo DISPLAY=$DISPLAY\n"))
272+
require.ErrorIs(t, err, io.EOF)
273+
err = sessions[0].sess.Wait()
274+
require.Error(t, err)
275+
276+
// Cleanup.
277+
for _, sh := range sessions[1:] {
278+
err = sh.stdin.Close()
279+
require.NoError(t, err)
280+
err = sh.sess.Wait()
281+
require.NoError(t, err)
282+
}
283+
err = extraSess.Close()
284+
require.ErrorIs(t, err, io.EOF)
285+
286+
err = s.Close()
287+
require.NoError(t, err)
288+
_ = testutil.TryReceive(ctx, t, done)
289+
}

0 commit comments

Comments
 (0)
0