8000 feat(agent): add connection reporting for SSH and reconnecting PTY by mafredri · Pull Request #16652 · coder/coder · GitHub
[go: up one dir, main page]

Skip to content

feat(agent): add connection reporting for SSH and reconnecting PTY #16652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat(agent): add connection reporting for SSH and reconnecting PTY
Updates #15139
  • Loading branch information
mafredri committed Feb 21, 2025
commit 58463c6c9fe207cb8999b877581d90232cbbae41
116 changes: 116 additions & 0 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
"tailscale.com/net/speedtest"
"tailscale.com/tailcfg"
"tailscale.com/types/netlogtype"
Expand Down Expand Up @@ -174,6 +175,7 @@ func New(options Options) Agent {
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
reportConnectionsUpdate: make(chan struct{}, 1),
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
reportMetadataInterval: options.ReportMetadataInterval,
Expand Down Expand Up @@ -247,6 +249,10 @@ type agent struct {
lifecycleStates []agentsdk.PostLifecycleRequest
lifecycleLastReportedIndex int // Keeps track of the last lifecycle state we successfully reported.

reportConnectionsUpdate chan struct{}
reportConnectionsMu sync.Mutex
reportConnections []*proto.ReportConnectionRequest

network *tailnet.Conn
statsReporter *statsReporter
logSender *agentsdk.LogSender
Expand All @@ -272,6 +278,24 @@ func (a *agent) init() {
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
BlockFileTransfer: a.blockFileTransfer,
ReportConnection: func(id uuid.UUID, magicType agentssh.MagicSessionType, ip string) func(code int, reason string) {
var connectionType proto.Connection_Type
switch magicType {
case agentssh.MagicSessionTypeSSH:
connectionType = proto.Connection_SSH
case agentssh.MagicSessionTypeVSCode:
connectionType = proto.Connection_VSCODE
case agentssh.MagicSessionTypeJetBrains:
connectionType = proto.Connection_JETBRAINS
case agentssh.MagicSessionTypeUnknown:
connectionType = proto.Connection_TYPE_UNSPECIFIED
default:
a.logger.Error(a.hardCtx, "unhandled magic session type when reporting connection", slog.F("magic_type", magicType))
connectionType = proto.Connection_TYPE_UNSPECIFIED
}

return a.reportConnection(id, connectionType, ip)
},
})
if err != nil {
panic(err)
Expand All @@ -294,6 +318,9 @@ func (a *agent) init() {
a.reconnectingPTYServer = reconnectingpty.NewServer(
a.logger.Named("reconnecting-pty"),
a.sshServer,
func(id uuid.UUID, ip string) func(code int, reason string) {
return a.reportConnection(id, proto.Connection_RECONNECTING_PTY, ip)
},
a.metrics.connectionsTotal, a.metrics.reconnectingPTYErrors,
a.reconnectingPTYTimeout,
)
Expand Down Expand Up @@ -703,6 +730,91 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
}
}

// reportConnectionsLoop reports connections to the agent for auditing.
func (a *agent) reportConnectionsLoop(ctx context.Context, aAPI proto.DRPCAgentClient24) error {
for {
select {
case <-a.reportConnectionsUpdate:
case <-ctx.Done():
return ctx.Err()
}

for {
a.reportConnectionsMu.Lock()
if len(a.reportConnections) == 0 {
a.reportConnectionsMu.Unlock()
break
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a label here for clarity? break will always break the innermost loop but I always have trouble remembering that personally.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you review in VS Code, the syntax highlighting can be helpful here!
image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only use labels when I need to. I don't think we have examples in our code for labels breaking the inner loop. Do we?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You got me 😄 ❤️

}
payload := a.reportConnections[0]
a.reportConnectionsMu.Unlock()

logger := a.logger.With(slog.F("payload", payload))
logger.Debug(ctx, "reporting connection")
_, err := aAPI.ReportConnection(ctx, payload)
if err != nil {
return xerrors.Errorf("failed to report connection: %w", err)
}

logger.Debug(ctx, "successfully reported connection")

a.reportConnectionsMu.Lock()
a.reportConnections = a.reportConnections[1:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just make reportConnections channel? If it's an append only slice that is only read in this function.

This slice behavior is correct, just feels like a weaker implementation of a channel.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it could be a channel, but how big to make it? How many in-flight reports is "too many"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair, a channel would not be a terrible option here. It requires upfront allocation which can either be a good or a bad thing in memory constrained systems. For now I've limited this to 2048 reports pending, or about 300KB. We can revisit this later if needed.

count := len(a.reportConnections)
a.reportConnectionsMu.Unlock()

if count == 0 {
break
}
}
}
}

func (a *agent) reportConnection(id uuid.UUID, connectionType proto.Connection_Type, ip string) (disconnected func(code int, reason string)) {
// Remove the port from the IP.
if portIndex := strings.LastIndex(ip, ":"); portIndex != -1 {
ip = ip[:portIndex]
ip = strings.Trim(ip, "[]") // IPv6 addresses are wrapped in brackets.
}

a.reportConnectionsMu.Lock()
defer a.reportConnectionsMu.Unlock()
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
Connection: &proto.Connection{
Id: id[:],
Action: proto.Connection_CONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: 0,
Reason: nil,
},
})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to worry about the size of this slice?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typically no. The only way these accumulate is if we have lost connection to coders. I could add some safe-guards around this to drop messages, but since it's part of auditing that feels a bit wrong. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll defer to your judgment. If you feel it is relatively bounded, then it's good with me. Maybe leave a comment about the assumptions?

I did not spend that long understanding the full context of the code.

select {
case a.reportConnectionsUpdate <- struct{}{}:
default:
}

return func(code int, reason string) {
a.reportConnectionsMu.Lock()
defer a.reportConnectionsMu.Unlock()
a.reportConnections = append(a.reportConnections, &proto.ReportConnectionRequest{
Connection: &proto.Connection{
Id: id[:],
Action: proto.Connection_DISCONNECT,
Type: connectionType,
Timestamp: timestamppb.New(time.Now()),
Ip: ip,
StatusCode: int32(code), //nolint:gosec
Reason: &reason,
},
})
select {
case a.reportConnectionsUpdate <- struct{}{}:
default:
}
}
}

// fetchServiceBannerLoop fetches the service banner on an interval. It will
// not be fetched immediately; the expectation is that it is primed elsewhere
// (and must be done before the session actually starts).
Expand Down Expand Up @@ -813,6 +925,10 @@ func (a *agent) run() (retErr error) {
return resourcesmonitor.Start(ctx)
})

// Connection reports are part of auditing, we should keep sending them via
// gracefulShutdownBehaviorRemain.
connMan.startAgentAPI("report connections", gracefulShutdownBehaviorRemain, a.reportConnectionsLoop)

// channels to sync goroutines below
// handle manifest
// |
Expand Down
74 changes: 66 additions & 8 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
Expand Down Expand Up @@ -189,6 +189,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
_ = stdin.Close()
err = session.Wait()
require.NoError(t, err)

assertConnectionReport(t, agentClient, proto.Connection_VSCODE, 0, "")
})

t.Run("TracksJetBrains", func(t *testing.T) {
Expand Down Expand Up @@ -225,7 +227,7 @@ func TestAgent_Stats_Magic(t *testing.T) {
remotePort := sc.Text()

//nolint:dogsled
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)

Expand Down Expand Up @@ -261,6 +263,8 @@ func TestAgent_Stats_Magic(t *testing.T) {
}, testutil.WaitLong, testutil.IntervalFast,
"never saw stats after conn closes",
)

assertConnectionReport(t, agentClient, proto.Connection_JETBRAINS, 0, "")
})
}

Expand Down Expand Up @@ -918,7 +922,7 @@ func TestAgent_SFTP(t *testing.T) {
home = "/" + strings.ReplaceAll(home, "\\", "/")
}
//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
Expand All @@ -941,6 +945,10 @@ func TestAgent_SFTP(t *testing.T) {
require.NoError(t, err)
_, err = os.Stat(tempFile)
require.NoError(t, err)

// Close the client to trigger disconnect event.
_ = client.Close()
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
}

func TestAgent_SCP(t *testing.T) {
Expand All @@ -950,7 +958,7 @@ func TestAgent_SCP(t *testing.T) {
defer cancel()

//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
sshClient, err := conn.SSHClient(ctx)
require.NoError(t, err)
defer sshClient.Close()
Expand All @@ -963,6 +971,10 @@ func TestAgent_SCP(t *testing.T) {
require.NoError(t, err)
_, err = os.Stat(tempFile)
require.NoError(t, err)

// Close the client to trigger disconnect event.
scpClient.Close()
assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "")
}

func TestAgent_FileTransferBlocked(t *testing.T) {
Expand All @@ -987,7 +999,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
defer cancel()

//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
})
sshClient, err := conn.SSHClient(ctx)
Expand All @@ -996,6 +1008,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
_, err = sftp.NewClient(sshClient)
require.Error(t, err)
assertFileTransferBlocked(t, err.Error())

assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
})

t.Run("SCP with go-scp package", func(t *testing.T) {
Expand All @@ -1005,7 +1019,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
defer cancel()

//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
})
sshClient, err := conn.SSHClient(ctx)
Expand All @@ -1018,6 +1032,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
err = scpClient.CopyFile(context.Background(), strings.NewReader("hello world"), tempFile, "0755")
require.Error(t, err)
assertFileTransferBlocked(t, err.Error())

assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
})

t.Run("Forbidden commands", func(t *testing.T) {
Expand All @@ -1031,7 +1047,7 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
defer cancel()

//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, func(_ *agenttest.Client, o *agent.Options) {
o.BlockFileTransfer = true
})
sshClient, err := conn.SSHClient(ctx)
Expand All @@ -1053,6 +1069,8 @@ func TestAgent_FileTransferBlocked(t *testing.T) {
msg, err := io.ReadAll(stdout)
require.NoError(t, err)
assertFileTransferBlocked(t, string(msg))

assertConnectionReport(t, agentClient, proto.Connection_SSH, agentssh.BlockedFileTransferErrorCode, "")
})
}
})
Expand Down Expand Up @@ -1661,8 +1679,16 @@ func TestAgent_ReconnectingPTY(t *testing.T) {
defer cancel()

//nolint:dogsled
conn, _, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
id := uuid.New()

// Test that the connection is reported. This must be tested in the
// first connection because we care about verifying all of these.
netConn0, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
_ = netConn0.Close()
assertConnectionReport(t, agentClient, proto.Connection_RECONNECTING_PTY, 0, "")

// --norc disables executing .bashrc, which is often used to customize the bash prompt
netConn1, err := conn.ReconnectingPTY(ctx, id, 80, 80, "bash --norc")
require.NoError(t, err)
Expand Down Expand Up @@ -2691,3 +2717,35 @@ func requireEcho(t *testing.T, conn net.Conn) {
require.NoError(t, err)
require.Equal(t, "test", string(b))
}

// assertConnectionReport waits for a connect/disconnect report pair from the
// agent client and verifies their ordering, type, timestamps, IPs, status
// codes and disconnect reason. A non-empty reason is matched as a substring
// of the disconnect report's reason.
func assertConnectionReport(t testing.TB, agentClient *agenttest.Client, connectionType proto.Connection_Type, status int, reason string) {
	t.Helper()

	var reports []*proto.ReportConnectionRequest
	// Note: msgAndArgs passed to assert.Eventually are evaluated at call
	// time, so we must not interpolate len(reports) there (it would always
	// print the initial value, 0). Log the final count after the wait
	// instead.
	if !assert.Eventually(t, func() bool {
		reports = agentClient.GetConnectionReports()
		return len(reports) >= 2
	}, testutil.WaitMedium, testutil.IntervalFast, "waiting for 2 connection reports or more") {
		t.Logf("got %d connection reports", len(reports))
		return
	}

	assert.Len(t, reports, 2, "want 2 connection reports")

	assert.Equal(t, proto.Connection_CONNECT, reports[0].GetConnection().GetAction(), "first report should be connect")
	assert.Equal(t, proto.Connection_DISCONNECT, reports[1].GetConnection().GetAction(), "second report should be disconnect")
	assert.Equal(t, connectionType, reports[0].GetConnection().GetType(), "connect type should be %s", connectionType)
	assert.Equal(t, connectionType, reports[1].GetConnection().GetType(), "disconnect type should be %s", connectionType)
	t1 := reports[0].GetConnection().GetTimestamp().AsTime()
	t2 := reports[1].GetConnection().GetTimestamp().AsTime()
	assert.True(t, t1.Before(t2) || t1.Equal(t2), "connect timestamp should be before or equal to disconnect timestamp")
	assert.NotEmpty(t, reports[0].GetConnection().GetIp(), "connect ip should not be empty")
	assert.NotEmpty(t, reports[1].GetConnection().GetIp(), "disconnect ip should not be empty")
	assert.Equal(t, 0, int(reports[0].GetConnection().GetStatusCode()), "connect status code should be 0")
	assert.Equal(t, status, int(reports[1].GetConnection().GetStatusCode()), "disconnect status code should be %d", status)
	assert.Equal(t, "", reports[0].GetConnection().GetReason(), "connect reason should be empty")
	if reason != "" {
		assert.Contains(t, reports[1].GetConnection().GetReason(), reason, "disconnect reason should contain %s", reason)
	} else {
		t.Logf("connection report disconnect reason: %s", reports[1].GetConnection().GetReason())
	}
}
Loading
Loading
0