From ad275dc27a1eda903ae8fa7d2d90a57ef83b070e Mon Sep 17 00:00:00 2001
From: Thomas Kosiewski
Date: Wed, 19 Feb 2025 18:30:33 +0100
Subject: [PATCH] fix(agent/agentssh): pin random seed for RSA key generation

Change-Id: I8c7e3070324e5d558374fd6891eea9d48660e1e9
Signed-off-by: Thomas Kosiewski
---
 agent/agent.go                           |  44 +++++++-
 agent/agentssh/agentssh.go               | 129 ++++++++++++++++++++---
 agent/agentssh/agentssh_internal_test.go |   2 +
 agent/agentssh/agentssh_test.go          |   8 ++
 agent/agentssh/x11_test.go               |   2 +
 cli/ssh_test.go                          |  65 ++++++++++++
 6 files changed, 233 insertions(+), 17 deletions(-)

diff --git a/agent/agent.go b/agent/agent.go
index 523892d3f65c9..0b3a6b3ecd2cf 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"hash/fnv"
 	"io"
 	"net/http"
 	"net/netip"
@@ -994,7 +995,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
 		if err := manifestOK.wait(ctx); err != nil {
 			return xerrors.Errorf("no manifest: %w", err)
 		}
-		var err error
 		defer func() {
 			networkOK.complete(retErr)
 		}()
@@ -1003,9 +1003,20 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
 		network := a.network
 		a.closeMutex.Unlock()
 		if network == nil {
+			keySeed, err := WorkspaceKeySeed(manifest.WorkspaceID, manifest.AgentName)
+			if err != nil {
+				return xerrors.Errorf("generate seed from workspace id: %w", err)
+			}
 			// use the graceful context here, because creating the tailnet is not itself tied to the
 			// agent API.
-			network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections)
+			network, err = a.createTailnet(
+				a.gracefulCtx,
+				manifest.AgentID,
+				manifest.DERPMap,
+				manifest.DERPForceWebSockets,
+				manifest.DisableDirectConnections,
+				keySeed,
+			)
 			if err != nil {
 				return xerrors.Errorf("create tailnet: %w", err)
 			}
@@ -1145,7 +1156,13 @@ func (a *agent) trackGoroutine(fn func()) error {
 	return nil
 }
 
-func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, derpForceWebSockets, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
+func (a *agent) createTailnet(
+	ctx context.Context,
+	agentID uuid.UUID,
+	derpMap *tailcfg.DERPMap,
+	derpForceWebSockets, disableDirectConnections bool,
+	keySeed int64,
+) (_ *tailnet.Conn, err error) {
 	// Inject `CODER_AGENT_HEADER` into the DERP header.
 	var header http.Header
 	if client, ok := a.client.(*agentsdk.Client); ok {
@@ -1172,6 +1189,10 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
 		}
 	}()
 
+	if err := a.sshServer.UpdateHostSigner(keySeed); err != nil {
+		return nil, xerrors.Errorf("update host signer: %w", err)
+	}
+
 	sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
 	if err != nil {
 		return nil, xerrors.Errorf("listen on the ssh port: %w", err)
@@ -1849,3 +1870,20 @@ func PrometheusMetricsHandler(prometheusRegistry *prometheus.Registry, logger sl
 		}
 	})
 }
+
+// WorkspaceKeySeed converts a WorkspaceID UUID and agent name to an int64 hash.
+// This uses the FNV-1a hash algorithm which provides decent distribution and collision
+// resistance for string inputs.
+func WorkspaceKeySeed(workspaceID uuid.UUID, agentName string) (int64, error) {
+	h := fnv.New64a()
+	_, err := h.Write(workspaceID[:])
+	if err != nil {
+		return 42, err
+	}
+	_, err = h.Write([]byte(agentName))
+	if err != nil {
+		return 42, err
+	}
+
+	return int64(h.Sum64()), nil
+}
diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go
index 0f7d0adadc865..a7e028541aa6e 100644
--- a/agent/agentssh/agentssh.go
+++ b/agent/agentssh/agentssh.go
@@ -3,11 +3,12 @@ package agentssh
 import (
 	"bufio"
 	"context"
-	"crypto/rand"
 	"crypto/rsa"
 	"errors"
 	"fmt"
 	"io"
+	"math/big"
+	"math/rand"
 	"net"
 	"os"
 	"os/exec"
@@ -128,17 +129,6 @@ type Server struct {
 }
 
 func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) {
-	// Clients' should ignore the host key when connecting.
-	// The agent needs to authenticate with coderd to SSH,
-	// so SSH authentication doesn't improve security.
-	randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
-	if err != nil {
-		return nil, err
-	}
-	randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
-	if err != nil {
-		return nil, err
-	}
 	if config == nil {
 		config = &Config{}
 	}
@@ -205,8 +195,10 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
 				slog.F("local_addr", conn.LocalAddr()),
 				slog.Error(err))
 		},
-		Handler:     s.sessionHandler,
-		HostSigners: []ssh.Signer{randomSigner},
+		Handler: s.sessionHandler,
+		// HostSigners are intentionally empty, as the host key will
+		// be set before we start listening.
+		HostSigners: []ssh.Signer{},
 		LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
 			// Allow local port forwarding all!
 			s.logger.Debug(ctx, "local port forward",
@@ -844,7 +836,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
 	return cmd, nil
 }
 
+// Serve starts the server to handle incoming connections on the provided listener.
+// It returns an error if no host keys are set or if there is an issue accepting connections.
 func (s *Server) Serve(l net.Listener) (retErr error) {
+	if len(s.srv.HostSigners) == 0 {
+		return xerrors.New("no host keys set")
+	}
+
 	s.logger.Info(context.Background(), "started serving listener", slog.F("listen_addr", l.Addr()))
 	defer func() {
 		s.logger.Info(context.Background(), "stopped serving listener",
@@ -1099,3 +1097,106 @@ func userHomeDir() (string, error) {
 	}
 	return u.HomeDir, nil
 }
+
+// UpdateHostSigner updates the host signer with a new key generated from the provided seed.
+// If an existing host key exists with the same algorithm, it is overwritten.
+func (s *Server) UpdateHostSigner(seed int64) error {
+	key, err := CoderSigner(seed)
+	if err != nil {
+		return err
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.srv.AddHostKey(key)
+
+	return nil
+}
+
+// CoderSigner generates a deterministic SSH signer based on the provided seed.
+// It uses RSA with a key size of roughly 2048 bits (two 1025-bit primes).
+func CoderSigner(seed int64) (gossh.Signer, error) {
+	// Clients should ignore the host key when connecting.
+	// The agent needs to authenticate with coderd to SSH,
+	// so SSH authentication doesn't improve security.
+
+	// Since the standard lib purposefully does not generate
+	// deterministic rsa keys, we need to do it ourselves.
+	coderHostKey := func() *rsa.PrivateKey {
+		// Create deterministic random source
+		// nolint: gosec
+		deterministicRand := rand.New(rand.NewSource(seed))
+
+		// Use fixed values for p and q based on the seed
+		p := big.NewInt(0)
+		q := big.NewInt(0)
+		e := big.NewInt(65537) // Standard RSA public exponent
+
+		for {
+			// Generate deterministic primes using the seeded random
+			// Each prime should be ~1024 bits to get a 2048-bit key
+			for {
+				p.SetBit(p, 1024, 1) // Ensure it's large enough
+				for i := 0; i < 1024; i++ {
+					if deterministicRand.Int63()%2 == 1 {
+						p.SetBit(p, i, 1)
+					} else {
+						p.SetBit(p, i, 0)
+					}
+				}
+				if p.ProbablyPrime(20) {
+					break
+				}
+			}
+
+			for {
+				q.SetBit(q, 1024, 1) // Ensure it's large enough
+				for i := 0; i < 1024; i++ {
+					if deterministicRand.Int63()%2 == 1 {
+						q.SetBit(q, i, 1)
+					} else {
+						q.SetBit(q, i, 0)
+					}
+				}
+				if q.ProbablyPrime(20) && p.Cmp(q) != 0 {
+					break
+				}
+			}
+
+			// Calculate n = p * q
+			n := new(big.Int).Mul(p, q)
+
+			// Calculate phi = (p-1) * (q-1)
+			p1 := new(big.Int).Sub(p, big.NewInt(1))
+			q1 := new(big.Int).Sub(q, big.NewInt(1))
+			phi := new(big.Int).Mul(p1, q1)
+
+			// Calculate private exponent d. ModInverse returns nil when e
+			// and phi are not coprime; no valid key exists for that prime
+			// pair, so draw a fresh pair from the same seeded stream.
+			d := new(big.Int).ModInverse(e, phi)
+			if d == nil {
+				continue
+			}
+
+			// Create the private key
+			privateKey := &rsa.PrivateKey{
+				PublicKey: rsa.PublicKey{
+					N: n,
+					E: int(e.Int64()),
+				},
+				D:      d,
+				Primes: []*big.Int{p, q},
+			}
+
+			// Compute precomputed values
+			privateKey.Precompute()
+
+			return privateKey
+		}
+	}()
+
+	coderSigner, err := gossh.NewSignerFromKey(coderHostKey)
+	return coderSigner, err
+}
diff --git a/agent/agentssh/agentssh_internal_test.go b/agent/agentssh/agentssh_internal_test.go
index 0ffa45df19b0d..5a319fa0055c9 100644
--- a/agent/agentssh/agentssh_internal_test.go
+++ b/agent/agentssh/agentssh_internal_test.go
@@ -39,6 +39,8 @@ func Test_sessionStart_orphan(t *testing.T) {
 	s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
 	require.NoError(t, err)
 	defer s.Close()
+	err = s.UpdateHostSigner(42)
+	assert.NoError(t, err)
 
 	// Here we're going to call the handler directly with a faked SSH session
 	// that just uses io.Pipes instead of a network socket. There is a large
diff --git a/agent/agentssh/agentssh_test.go b/agent/agentssh/agentssh_test.go
index b9cec420e5651..378657ebee5ad 100644
--- a/agent/agentssh/agentssh_test.go
+++ b/agent/agentssh/agentssh_test.go
@@ -41,6 +41,8 @@ func TestNewServer_ServeClient(t *testing.T) {
 	s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
 	require.NoError(t, err)
 	defer s.Close()
+	err = s.UpdateHostSigner(42)
+	assert.NoError(t, err)
 
 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	require.NoError(t, err)
@@ -146,6 +148,8 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
 	s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
 	require.NoError(t, err)
 	defer s.Close()
+	err = s.UpdateHostSigner(42)
+	assert.NoError(t, err)
 
 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	require.NoError(t, err)
@@ -197,6 +201,8 @@ func TestNewServer_Signal(t *testing.T) {
 		s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
 		require.NoError(t, err)
 		defer s.Close()
+		err = s.UpdateHostSigner(42)
+		assert.NoError(t, err)
 
 		ln, err := net.Listen("tcp", "127.0.0.1:0")
 		require.NoError(t, err)
@@ -262,6 +268,8 @@ func TestNewServer_Signal(t *testing.T) {
 		s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
 		require.NoError(t, err)
 		defer s.Close()
+		err = s.UpdateHostSigner(42)
+		assert.NoError(t, err)
 
 		ln, err := net.Listen("tcp", "127.0.0.1:0")
 		require.NoError(t, err)
diff --git a/agent/agentssh/x11_test.go b/agent/agentssh/x11_test.go
index 057da9a21e642..2ccbbfe69ca5c 100644
--- a/agent/agentssh/x11_test.go
+++ b/agent/agentssh/x11_test.go
@@ -38,6 +38,8 @@ func TestServer_X11(t *testing.T) {
 	s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
 	require.NoError(t, err)
 	defer s.Close()
+	err = s.UpdateHostSigner(42)
+	assert.NoError(t, err)
 
 	ln, err := net.Listen("tcp", "127.0.0.1:0")
 	require.NoError(t, err)
diff --git a/cli/ssh_test.go b/cli/ssh_test.go
index b403f7ff83a8e..d20278bbf7ced 100644
--- a/cli/ssh_test.go
+++ b/cli/ssh_test.go
@@ -453,6 +453,71 @@ func TestSSH(t *testing.T) {
 		<-cmdDone
 	})
 
+	t.Run("DeterministicHostKey", func(t *testing.T) {
+		t.Parallel()
+		client, workspace, agentToken := setupWorkspaceForAgent(t)
+		_, _ = tGoContext(t, func(ctx context.Context) {
+			// Run this async so the SSH command has to wait for
+			// the build and agent to connect!
+			_ = agenttest.New(t, client.URL, agentToken)
+			<-ctx.Done()
+		})
+
+		clientOutput, clientInput := io.Pipe()
+		serverOutput, serverInput := io.Pipe()
+		defer func() {
+			for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
+				_ = c.Close()
+			}
+		}()
+
+		ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
+		defer cancel()
+
+		inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
+		clitest.SetupConfig(t, client, root)
+		inv.Stdin = clientOutput
+		inv.Stdout = serverInput
+		inv.Stderr = io.Discard
+
+		cmdDone := tGo(t, func() {
+			err := inv.WithContext(ctx).Run()
+			assert.NoError(t, err)
+		})
+
+		keySeed, err := agent.WorkspaceKeySeed(workspace.ID, "dev")
+		require.NoError(t, err)
+
+		signer, err := agentssh.CoderSigner(keySeed)
+		require.NoError(t, err)
+
+		conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
+			Reader: serverOutput,
+			Writer: clientInput,
+		}, "", &ssh.ClientConfig{
+			HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
+		})
+		require.NoError(t, err)
+		defer conn.Close()
+
+		sshClient := ssh.NewClient(conn, channels, requests)
+		session, err := sshClient.NewSession()
+		require.NoError(t, err)
+		defer session.Close()
+
+		command := "sh -c exit"
+		if runtime.GOOS == "windows" {
+			command = "cmd.exe /c exit"
+		}
+		err = session.Run(command)
+		require.NoError(t, err)
+		err = sshClient.Close()
+		require.NoError(t, err)
+		_ = clientOutput.Close()
+
+		<-cmdDone
+	})
+
 	t.Run("NetworkInfo", func(t *testing.T) {
 		t.Parallel()
 		client, workspace, agentToken := setupWorkspaceForAgent(t)