@@ -37,12 +37,30 @@ const (
37
37
X11MaxPort = X11StartPort + X11MaxDisplays
38
38
)
39
39
40
+ // X11Network abstracts the creation of network listeners for X11 forwarding.
41
+ // It is intended mainly for testing; production code uses the default
42
+ // implementation backed by the operating system networking stack.
43
+ type X11Network interface {
44
+ Listen (network , address string ) (net.Listener , error )
45
+ }
46
+
47
+ // osNet is the default X11Network implementation that uses the standard
48
+ // library network stack.
49
+ type osNet struct {}
50
+
51
+ func (osNet ) Listen (network , address string ) (net.Listener , error ) {
52
+ return net .Listen (network , address )
53
+ }
54
+
40
55
type x11Forwarder struct {
41
56
logger slog.Logger
42
57
x11HandlerErrors * prometheus.CounterVec
43
58
fs afero.Fs
44
59
displayOffset int
45
60
61
+ // network creates X11 listener sockets. Defaults to osNet{}.
62
+ network X11Network
63
+
46
64
mu sync.Mutex
47
65
sessions map [* x11Session ]struct {}
48
66
connections map [net.Conn ]struct {}
@@ -147,26 +165,27 @@ func (x *x11Forwarder) listenForConnections(
147
165
x .closeAndRemoveSession (session )
148
166
}
149
167
150
- tcpConn , ok := conn .(* net.TCPConn )
151
- if ! ok {
152
- x .logger .Warn (ctx , fmt .Sprintf ("failed to cast connection to TCPConn. got: %T" , conn ))
153
- _ = conn .Close ()
154
- continue
168
+ var originAddr string
169
+ var originPort uint32
170
+
171
+ if tcpConn , ok := conn .(* net.TCPConn ); ok {
172
+ if tcpAddr , ok := tcpConn .LocalAddr ().(* net.TCPAddr ); ok {
173
+ originAddr = tcpAddr .IP .String ()
174
+ // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
175
+ originPort = uint32 (tcpAddr .Port )
176
+ }
155
177
}
156
- tcpAddr , ok := tcpConn .LocalAddr ().(* net.TCPAddr )
157
- if ! ok {
158
- x .logger .Warn (ctx , fmt .Sprintf ("failed to cast local address to TCPAddr. got: %T" , tcpConn .LocalAddr ()))
159
- _ = conn .Close ()
160
- continue
178
+ // Fallback values for in-memory or non-TCP connections.
179
+ if originAddr == "" {
180
+ originAddr = "127.0.0.1"
161
181
}
162
182
163
183
channel , reqs , err := serverConn .OpenChannel ("x11" , gossh .Marshal (struct {
164
184
OriginatorAddress string
165
185
OriginatorPort uint32
166
186
}{
167
- OriginatorAddress : tcpAddr .IP .String (),
168
- // #nosec G115 - Safe conversion as TCP port numbers are within uint32 range (0-65535)
169
- OriginatorPort : uint32 (tcpAddr .Port ),
187
+ OriginatorAddress : originAddr ,
188
+ OriginatorPort : originPort ,
170
189
}))
171
190
if err != nil {
172
191
x .logger .Warn (ctx , "failed to open X11 channel" , slog .Error (err ))
@@ -287,13 +306,13 @@ func (x *x11Forwarder) evictLeastRecentlyUsedSession() {
287
306
// createX11Listener creates a listener for X11 forwarding, it will use
288
307
// the next available port starting from X11StartPort and displayOffset.
289
308
func (x * x11Forwarder ) createX11Listener (ctx context.Context ) (ln net.Listener , display int , err error ) {
290
- var lc net.ListenConfig
291
309
// Look for an open port to listen on.
292
310
for port := X11StartPort + x .displayOffset ; port <= X11MaxPort ; port ++ {
293
311
if ctx .Err () != nil {
294
312
return nil , - 1 , ctx .Err ()
295
313
}
296
- ln , err = lc .Listen (ctx , "tcp" , fmt .Sprintf ("localhost:%d" , port ))
314
+
315
+ ln , err = x .network .Listen ("tcp" , fmt .Sprintf ("localhost:%d" , port ))
297
316
if err == nil {
298
317
display = port - X11StartPort
299
318
return ln , display , nil
0 commit comments