8000 fix: Improve agent connection tracking when agent is closed · coder/coder@874eff6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 874eff6

Browse files
committed
fix: Improve agent connection tracking when agent is closed
1 parent ee4f0fc commit 874eff6

File tree

1 file changed

+72
-33
lines changed

1 file changed

+72
-33
lines changed

agent/agent.go

Lines changed: 72 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,27 @@ func (a *agent) run(ctx context.Context) error {
231231
return nil
232232
}
233233

234-
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*tailnet.Conn, error) {
234+
func (a *agent) trackConnGoroutine(fn func()) error {
235+
a.closeMutex.Lock()
236+
defer a.closeMutex.Unlock()
237+
if a.isClosed() {
238+
return xerrors.New("track conn goroutine: agent is closed")
239+
}
240+
a.connCloseWait.Add(1)
241+
go func() {
242+
defer a.connCloseWait.Done()
243+
fn()
244+
}()
245+
return nil
246+
}
247+
248+
func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (network *tailnet.Conn, err error) {
235249
a.closeMutex.Lock()
236250
if a.isClosed() {
237251
a.closeMutex.Unlock()
238252
return nil, xerrors.New("closed")
239253
}
240-
network, err := tailnet.NewConn(&tailnet.Options{
254+
network, err = tailnet.NewConn(&tailnet.Options{
241255
Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)},
242256
DERPMap: derpMap,
243257
Logger: a.logger.Named("tailnet"),
@@ -248,30 +262,39 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
248262
return nil, xerrors.Errorf("create tailnet: %w", err)
249263
}
250264
a.network = network
251-
a.connCloseWait.Add(4)
252265
a.closeMutex.Unlock()
253266

254267
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSSHPort))
255268
if err != nil {
256269
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
257270
}
258-
go func() {
259-
defer a.connCloseWait.Done()
271+
defer func() {
272+
if err != nil {
273+
_ = sshListener.Close()
274+
}
275+
}()
276+
if err = a.trackConnGoroutine(func() {
260277
for {
261278
conn, err := sshListener.Accept()
262279
if err != nil {
263280
return
264281
}
265282
go a.sshServer.HandleConn(conn)
266283
}
267-
}()
284+
}); err != nil {
285+
return nil, err
286+
}
268287

269288
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetReconnectingPTYPort))
270289
if err != nil {
271290
return nil, xerrors.Errorf("listen for reconnecting pty: %w", err)
272291
}
273-
go func() {
274-
defer a.connCloseWait.Done()
292+
defer func() {
293+
if err != nil {
294+
_ = reconnectingPTYListener.Close()
295+
}
296+
}()
297+
if err = a.trackConnGoroutine(func() {
275298
for {
276299
conn, err := reconnectingPTYListener.Accept()
277300
if err != nil {
@@ -298,36 +321,48 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
298321
}
299322
go a.handleReconnectingPTY(ctx, msg, conn)
300323
}
301-
}()
324+
}); err != nil {
325+
return nil, err
326+
}
302327

303328
speedtestListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetSpeedtestPort))
304329
if err != nil {
305330
return nil, xerrors.Errorf("listen for speedtest: %w", err)
306331
}
307-
go func() {
308-
defer a.connCloseWait.Done()
332+
defer func() {
333+
if err != nil {
334+
_ = speedtestListener.Close()
335+
}
336+
}()
337+
if err = a.trackConnGoroutine(func() {
309338
for {
310339
conn, err := speedtestListener.Accept()
311340
if err != nil {
312341
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
313342
return
314343
}
315-
a.closeMutex.Lock()
316-
a.connCloseWait.Add(1)
317-
a.closeMutex.Unlock()
318-
go func() {
319-
defer a.connCloseWait.Done()
344+
if err = a.trackConnGoroutine(func() {
320345
_ = speedtest.ServeConn(conn)
321-
}()
346+
}); err != nil {
347+
a.logger.Debug(ctx, "speedtest listener failed", slog.Error(err))
348+
_ = conn.Close()
349 1E0A +
return
350+
}
322351
}
323-
}()
352+
}); err != nil {
353+
return nil, err
354+
}
324355

325356
statisticsListener, err := network.Listen("tcp", ":"+strconv.Itoa(codersdk.TailnetStatisticsPort))
326357
if err != nil {
327358
return nil, xerrors.Errorf("listen for statistics: %w", err)
328359
}
329-
go func() {
330-
defer a.connCloseWait.Done()
360+
defer func() {
361+
if err != nil {
362+
_ = statisticsListener.Close()
363+
}
364+
}()
365+
if err = a.trackConnGoroutine(func() {
331366
defer statisticsListener.Close()
332367
server := &http.Server{
333368
Handler: a.statisticsHandler(),
@@ -341,11 +376,13 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
341376
_ = server.Close()
342377
}()
343378

344-
err = server.Serve(statisticsListener)
379+
err := server.Serve(statisticsListener)
345380
if err != nil && !xerrors.Is(err, http.ErrServerClosed) && !strings.Contains(err.Error(), "use of closed network connection") {
346381
a.logger.Critical(ctx, "serve statistics HTTP server", slog.Error(err))
347382
}
348-
}()
383+
}); err != nil {
384+
return nil, err
385+
}
349386

350387
return network, nil
351388
}
@@ -527,12 +564,15 @@ func (a *agent) init(ctx context.Context) {
527564
a.logger.Error(ctx, "report stats", slog.Error(err))
528565
return
529566
}
530-
a.connCloseWait.Add(1)
531-
go func() {
532-
defer a.connCloseWait.Done()
567+
568+
if err = a.trackConnGoroutine(func() {
533569
<-a.closed
534-
cl.Close()
535-
}()
570+
_ = cl.Close()
571+
}); err != nil {
572+
a.logger.Error(ctx, "report stats goroutine", slog.Error(err))
573+
_ = cl.Close()
574+
return
575+
}
536576
}
537577

538578
func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *codersdk.AgentStats {
@@ -787,9 +827,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
787827
return
788828
}
789829

790-
a.closeMutex.Lock()
791-
a.connCloseWait.Add(1)
792-
a.closeMutex.Unlock()
793830
ctx, cancelFunc := context.WithCancel(ctx)
794831
rpty = &reconnectingPTY{
795832
activeConns: map[string]net.Conn{
@@ -818,7 +855,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
818855
_ = process.Wait()
819856
rpty.Close()
820857
}()
821-
go func() {
858+
if err = a.trackConnGoroutine(func() {
822859
buffer := make([]byte, 1024)
823860
for {
824861
read, err := rpty.ptty.Output().Read(buffer)
@@ -846,8 +883,10 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
846883
_ = process.Kill()
847884
rpty.Close()
848885
a.reconnectingPTYs.Delete(msg.ID)
849-
a.connCloseWait.Done()
850-
}()
886+
}); err != nil {
887+
a.logger.Error(ctx, "start reconnecting pty routine", slog.F("id", msg.ID), slog.Error(err))
888+
return
889+
}
851890
}
852891
// Resize the PTY to initial height + width.
853892
err := rpty.ptty.Resize(msg.Height, msg.Width)

0 commit comments

Comments
 (0)
0