10000 chore: unit test for X11 eviction · coder/coder@7f2e241 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f2e241

Browse files
committed
chore: unit test for X11 eviction
1 parent b60589d commit 7f2e241

File tree

2 files changed

+167
-3
lines changed

2 files changed

+167
-3
lines changed

agent/agentssh/x11.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,17 @@ func (*Server) x11Callback(_ ssh.Context, _ ssh.X11) bool {
8484

8585
// x11Handler is called when a session has requested X11 forwarding.
8686
// It listens for X11 connections and forwards them to the client.
87-
func (x *x11Forwarder) x11Handler(ctx ssh.Context, sshSession ssh.Session) (displayNumber int, handled bool) {
87+
func (x *x11Forwarder) x11Handler(sshCtx ssh.Context, sshSession ssh.Session) (displayNumber int, handled bool) {
8888
x11, hasX11 := sshSession.X11()
8989
if !hasX11 {
9090
return -1, false
9191
}
92-
serverConn, valid := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
92+
serverConn, valid := sshCtx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
9393
if !valid {
94-
x.logger.Warn(ctx, "failed to get server connection")
94+
x.logger.Warn(sshCtx, "failed to get server connection")
9595
return -1, false
9696
}
97+
ctx := slog.With(sshCtx, slog.F("session_id", fmt.Sprintf("%x", serverConn.SessionID())))
9798

9899
hostname, err := os.Hostname()
99100
if err != nil {
@@ -128,6 +129,7 @@ func (x *x11Forwarder) x11Handler(ctx ssh.Context, sshSession ssh.Session) (disp
128129
}()
129130

130131
go x.listenForConnections(ctx, x11session, serverConn, x11)
132+
x.logger.Debug(ctx, "X11 forwarding started", slog.F("display", x11session.display))
131133

132134
return x11session.display, true
133135
}

agent/agentssh/x11_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"bytes"
66
"encoding/hex"
77
"fmt"
8+
"io"
89
"net"
910
"os"
1011
"path/filepath"
@@ -127,3 +128,164 @@ func TestServer_X11(t *testing.T) {
127128
_, err = fs.Stat(filepath.Join(home, ".Xauthority"))
128129
require.NoError(t, err)
129130
}
131+
132+
func TestServer_X11_EvictionLRU(t *testing.T) {
133+
t.Parallel()
134+
if runtime.GOOS != "linux" {
135+
t.Skip("X11 forwarding is only supported on Linux")
136+
}
137+
138+
ctx := testutil.Context(t, testutil.WaitLong)
139+
logger := testutil.Logger(t)
140+
fs := afero.NewMemMapFs()
141+
142+
// Use in-process networking for X11 forwarding.
143+
inproc := testutil.NewInProcNet()
144+
145+
cfg := &agentssh.Config{
146+
X11Net: inproc,
147+
}
148+
149+
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, cfg)
150+
require.NoError(t, err)
151+
defer s.Close()
152+
err = s.UpdateHostSigner(42)
153+
require.NoError(t, err)
154+
155+
ln, err := net.Listen("tcp", "127.0.0.1:0")
156+
require.NoError(t, err)
157+
158+
done := make(chan struct{})
159+
go func() {
160+
defer close(done)
161+
err := s.Serve(ln)
162+
assert.Error(t, err) // Server is closed once we call s.Close().
163+
}()
164+
165+
c := sshClient(t, ln.Addr().String())
166+
167+
// Calculate how many simultaneous X11 sessions we can create given the
168+
// configured port range.
169+
startPort := agentssh.X11StartPort + agentssh.X11DefaultDisplayOffset
170+
maxSessions := agentssh.X11MaxPort - startPort + 1
171+
require.Greater(t, maxSessions, 0, "expected a positive maxSessions value")
172+
173+
// shellSession holds references to the session and its standard streams so
174+
// that the test can keep them open (and optionally interact with them) for
175+
// the lifetime of the test. If we don't start the Shell with pipes in place,
176+
// the session will be torn down asynchronously during the test.
177+
type shellSession struct {
178+
sess *gossh.Session
179+
stdin io.WriteCloser
180+
stdout io.Reader
181+
stderr io.Reader
182+
// scanner is used to read the output of the session, line by line.
183+
scanner *bufio.Scanner
184+
}
185+
186+
sessions := make([]shellSession, 0, maxSessions)
187+
for i := 0; i < maxSessions; i++ {
188+
sess, err := c.NewSession()
189+
require.NoError(t, err)
190+
191+
_, err = sess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
192+
AuthProtocol: "MIT-MAGIC-COOKIE-1",
193+
AuthCookie: hex.EncodeToString([]byte(fmt.Sprintf("cookie%d", i))),
194+
ScreenNumber: uint32(0),
195+
}))
196+
require.NoError(t, err)
197+
198+
stdin, err := sess.StdinPipe()
199+
require.NoError(t, err)
200+
stdout, err := sess.StdoutPipe()
201+
require.NoError(t, err)
202+
stderr, err := sess.StderrPipe()
203+
require.NoError(t, err)
204+
require.NoError(t, sess.Shell())
205+
206+
// The SSH server lazily starts the session. We need to write a command
207+
// and read back to ensure the X11 forwarding is started.
208+
scanner := bufio.NewScanner(stdout)
209+
msg := fmt.Sprintf("ready-%d", i)
210+
_, err = stdin.Write([]byte("echo " + msg + "\n"))
211+
require.NoError(t, err)
212+
// Read until we get the message (first token may be empty due to shell prompt)
213+
for scanner.Scan() {
214+
line := strings.TrimSpace(scanner.Text())
215+
if strings.Contains(line, msg) {
216+
break
217+
}
218+
}
219+
require.NoError(t, scanner.Err())
220+
221+
sessions = append(sessions, shellSession{
222+
sess: sess,
223+
stdin: stdin,
224+
stdout: stdout,
225+
stderr: stderr,
226+
scanner: scanner,
227+
})
228+
}
229+
230+
// Create one more session which should evict the first (LRU) session and
231+
// therefore reuse the very first display/port.
232+
extraSess, err := c.NewSession()
233+
require.NoError(t, err)
234+
235+
_, err = extraSess.SendRequest("x11-req", true, gossh.Marshal(ssh.X11{
236+
AuthProtocol: "MIT-MAGIC-COOKIE-1",
237+
AuthCookie: hex.EncodeToString([]byte("extra")),
238+
ScreenNumber: uint32(0),
239+
}))
240+
require.NoError(t, err)
241+
242+
// Ask the remote side for the DISPLAY value so we can extract the display
243+
// number that was assigned to this session.
244+
out, err := extraSess.Output("echo DISPLAY=$DISPLAY")
245+
require.NoError(t, err)
246+
247+
// Example output line: "DISPLAY=localhost:10.0".
248+
var newDisplayNumber int
249+
{
250+
sc := bufio.NewScanner(bytes.NewReader(out))
251+
for sc.Scan() {
252+
line := strings.TrimSpace(sc.Text())
253+
if strings.HasPrefix(line, "DISPLAY=") {
254+
parts := strings.SplitN(line, ":", 2)
255+
require.Len(t, parts, 2)
256+
displayPart := parts[1]
257+
if strings.Contains(displayPart, ".") {
258+
displayPart = strings.SplitN(displayPart, ".", 2)[0]
259+
}
260+
var convErr error
261+
newDisplayNumber, convErr = strconv.Atoi(displayPart)
262+
require.NoError(t, convErr)
263+
break
264+
}
265+
}
266+
require.NoError(t, sc.Err())
267+
}
268+
269+
// The display number should have wrapped around to the starting value.
270+
assert.Equal(t, agentssh.X11DefaultDisplayOffset, newDisplayNumber, "expected display number to be reused after eviction")
271+
272+
// validate that the first session was torn down.
273+
_, err = sessions[0].stdin.Write([]byte("echo DISPLAY=$DISPLAY\n"))
274+
require.ErrorIs(t, err, io.EOF)
275+
err = sessions[0].sess.Wait()
276+
require.Error(t, err)
277+
278+
// Cleanup.
279+
for _, sh := range sessions[1:] {
280+
err = sh.stdin.Close()
281+
require.NoError(t, err)
282+
err = sh.sess.Wait()
283+
require.NoError(t, err)
284+
}
285+
err = extraSess.Close()
286+
require.ErrorIs(t, err, io.EOF)
287+
288+
err = s.Close()
289+
require.NoError(t, err)
290+
<-done
291+
}

0 commit comments

Comments
 (0)
0