@@ -231,13 +231,27 @@ func (a *agent) run(ctx context.Context) error {
231
231
return nil
232
232
}
233
233
234
- func (a * agent ) createTailnet (ctx context.Context , derpMap * tailcfg.DERPMap ) (* tailnet.Conn , error ) {
234
+ func (a * agent ) trackConnGoroutine (fn func ()) error {
235
+ a .closeMutex .Lock ()
236
+ defer a .closeMutex .Unlock ()
237
+ if a .isClosed () {
238
+ return xerrors .New ("track conn goroutine: agent is closed" )
239
+ }
240
+ a .connCloseWait .Add (1 )
241
+ go func () {
242
+ defer a .connCloseWait .Done ()
243
+ fn ()
244
+ }()
245
+ return nil
246
+ }
247
+
248
+ func (a * agent ) createTailnet (ctx context.Context , derpMap * tailcfg.DERPMap ) (network * tailnet.Conn , err error ) {
235
249
a .closeMutex .Lock ()
236
250
if a .isClosed () {
237
251
a .closeMutex .Unlock ()
238
252
return nil , xerrors .New ("closed" )
239
253
}
240
- network , err : = tailnet .NewConn (& tailnet.Options {
254
+ network , err = tailnet .NewConn (& tailnet.Options {
241
255
Addresses : []netip.Prefix {netip .PrefixFrom (codersdk .TailnetIP , 128 )},
242
256
DERPMap : derpMap ,
243
257
Logger : a .logger .Named ("tailnet" ),
@@ -248,30 +262,39 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
248
262
return nil , xerrors .Errorf ("create tailnet: %w" , err )
249
263
}
250
264
a .network = network
251
- a .connCloseWait .Add (4 )
252
265
a .closeMutex .Unlock ()
253
266
254
267
sshListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetSSHPort ))
255
268
if err != nil {
256
269
return nil , xerrors .Errorf ("listen on the ssh port: %w" , err )
257
270
}
258
- go func () {
259
- defer a .connCloseWait .Done ()
271
+ defer func () {
272
+ if err != nil {
273
+ _ = sshListener .Close ()
274
+ }
275
+ }()
276
+ if err = a .trackConnGoroutine (func () {
260
277
for {
261
278
conn , err := sshListener .Accept ()
262
279
if err != nil {
263
280
return
264
281
}
265
282
go a .sshServer .HandleConn (conn )
266
283
}
267
- }()
284
+ }); err != nil {
285
+ return nil , err
286
+ }
268
287
269
288
reconnectingPTYListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetReconnectingPTYPort ))
270
289
if err != nil {
271
290
return nil , xerrors .Errorf ("listen for reconnecting pty: %w" , err )
272
291
}
273
- go func () {
274
- defer a .connCloseWait .Done ()
292
+ defer func () {
293
+ if err != nil {
294
+ _ = reconnectingPTYListener .Close ()
295
+ }
296
+ }()
297
+ if err = a .trackConnGoroutine (func () {
275
298
for {
276
299
conn , err := reconnectingPTYListener .Accept ()
277
300
if err != nil {
@@ -298,36 +321,48 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
298
321
}
299
322
go a .handleReconnectingPTY (ctx , msg , conn )
300
323
}
301
- }()
324
+ }); err != nil {
325
+ return nil , err
326
+ }
302
327
303
328
speedtestListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetSpeedtestPort ))
304
329
if err != nil {
305
330
return nil , xerrors .Errorf ("listen for speedtest: %w" , err )
306
331
}
307
- go func () {
308
- defer a .connCloseWait .Done ()
332
+ defer func () {
333
+ if err != nil {
334
+ _ = speedtestListener .Close ()
335
+ }
336
+ }()
337
+ if err = a .trackConnGoroutine (func () {
309
338
for {
310
339
conn , err := speedtestListener .Accept ()
311
340
if err != nil {
312
341
a .logger .Debug (ctx , "speedtest listener failed" , slog .Error (err ))
313
342
return
314
343
}
315
- a .closeMutex .Lock ()
316
- a .connCloseWait .Add (1 )
317
- a .closeMutex .Unlock ()
318
- go func () {
319
- defer a .connCloseWait .Done ()
344
+ if err = a .trackConnGoroutine (func () {
320
345
_ = speedtest .ServeConn (conn )
321
- }()
346
+ }); err != nil {
347
+ a .logger .Debug (ctx , "speedtest listener failed" , slog .Error (err ))
348
+ _ = conn .Close ()
349
1E0A
+ return
350
+ }
322
351
}
323
- }()
352
+ }); err != nil {
353
+ return nil , err
354
+ }
324
355
325
356
statisticsListener , err := network .Listen ("tcp" , ":" + strconv .Itoa (codersdk .TailnetStatisticsPort ))
326
357
if err != nil {
327
358
return nil , xerrors .Errorf ("listen for statistics: %w" , err )
328
359
}
329
- go func () {
330
- defer a .connCloseWait .Done ()
360
+ defer func () {
361
+ if err != nil {
362
+ _ = statisticsListener .Close ()
363
+ }
364
+ }()
365
+ if err = a .trackConnGoroutine (func () {
331
366
defer statisticsListener .Close ()
332
367
server := & http.Server {
333
368
Handler : a .statisticsHandler (),
@@ -341,11 +376,13 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t
341
376
_ = server .Close ()
342
377
}()
343
378
344
- err = server .Serve (statisticsListener )
379
+ err : = server .Serve (statisticsListener )
345
380
if err != nil && ! xerrors .Is (err , http .ErrServerClosed ) && ! strings .Contains (err .Error (), "use of closed network connection" ) {
346
381
a .logger .Critical (ctx , "serve statistics HTTP server" , slog .Error (err ))
347
382
}
348
- }()
383
+ }); err != nil {
384
+ return nil , err
385
+ }
349
386
350
387
return network , nil
351
388
}
@@ -527,12 +564,15 @@ func (a *agent) init(ctx context.Context) {
527
564
a .logger .Error (ctx , "report stats" , slog .Error (err ))
528
565
return
529
566
}
530
- a .connCloseWait .Add (1 )
531
- go func () {
532
- defer a .connCloseWait .Done ()
567
+
568
+ if err = a .trackConnGoroutine (func () {
533
569
<- a .closed
534
- cl .Close ()
535
- }()
570
+ _ = cl .Close ()
571
+ }); err != nil {
572
+ a .logger .Error (ctx , "report stats goroutine" , slog .Error (err ))
573
+ _ = cl .Close ()
574
+ return
575
+ }
536
576
}
537
577
538
578
func convertAgentStats (counts map [netlogtype.Connection ]netlogtype.Counts ) * codersdk.AgentStats {
@@ -787,9 +827,6 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
787
827
return
788
828
}
789
829
790
- a .closeMutex .Lock ()
791
- a .connCloseWait .Add (1 )
792
- a .closeMutex .Unlock ()
793
830
ctx , cancelFunc := context .WithCancel (ctx )
794
831
rpty = & reconnectingPTY {
795
832
activeConns : map [string ]net.Conn {
@@ -818,7 +855,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
818
855
_ = process .Wait ()
819
856
rpty .Close ()
820
857
}()
821
- go func () {
858
+ if err = a . trackConnGoroutine ( func () {
822
859
buffer := make ([]byte , 1024 )
823
860
for {
824
861
read , err := rpty .ptty .Output ().Read (buffer )
@@ -846,8 +883,10 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, msg codersdk.Reconnec
846
883
_ = process .Kill ()
847
884
rpty .Close ()
848
885
a .reconnectingPTYs .Delete (msg .ID )
849
- a .connCloseWait .Done ()
850
- }()
886
+ }); err != nil {
887
+ a .logger .Error (ctx , "start reconnecting pty routine" , slog .F ("id" , msg .ID ), slog .Error (err ))
888
+ return
889
+ }
851
890
}
852
891
// Resize the PTY to initial height + width.
853
892
err := rpty .ptty .Resize (msg .Height , msg .Width )
0 commit comments