diff --git a/cli/server.go b/cli/server.go index 00e857677062d..59121e8c23d81 100644 --- a/cli/server.go +++ b/cli/server.go @@ -723,7 +723,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co // the request is not to a local IP. var handler http.Handler = coderAPI.RootHandler if cfg.RedirectToAccessURL.Value { - handler = redirectToAccessURL(handler, accessURLParsed, tunnel != nil) + handler = redirectToAccessURL(handler, accessURLParsed, tunnel != nil, appHostnameRegex) } // ReadHeaderTimeout is purposefully not enabled. It caused some @@ -1470,7 +1470,7 @@ func configureHTTPClient(ctx context.Context, clientCertFile, clientKeyFile stri } // nolint:revive -func redirectToAccessURL(handler http.Handler, accessURL *url.URL, tunnel bool) http.Handler { +func redirectToAccessURL(handler http.Handler, accessURL *url.URL, tunnel bool, appHostnameRegex *regexp.Regexp) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { redirect := func() { http.Redirect(w, r, accessURL.String(), http.StatusTemporaryRedirect) @@ -1484,12 +1484,17 @@ func redirectToAccessURL(handler http.Handler, accessURL *url.URL, tunnel bool) return } - if r.Host != accessURL.Host { - redirect() + if r.Host == accessURL.Host { + handler.ServeHTTP(w, r) + return + } + + if appHostnameRegex != nil && appHostnameRegex.MatchString(r.Host) { + handler.ServeHTTP(w, r) return } - handler.ServeHTTP(w, r) + redirect() }) } diff --git a/cli/server_test.go b/cli/server_test.go index 0a41d89d50cd4..17b14bd3dd916 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -32,6 +32,7 @@ import ( "github.com/coder/coder/cli/clitest" "github.com/coder/coder/cli/config" + "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/database/postgres" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/codersdk" @@ -70,11 +71,7 @@ func TestServer(t *testing.T) { accessURL := waitAccessURL(t, cfg) client := codersdk.New(accessURL) - _, err = client.CreateFirstUser(ctx, codersdk.CreateFirstUserRequest{ - Email: "some@one.com", - Username: "example", - Password: "password", - }) + _, err = client.CreateFirstUser(ctx, coderdtest.FirstUserParams) require.NoError(t, err) cancelFunc() require.NoError(t, <-errC) @@ -540,6 +537,7 @@ func TestServer(t *testing.T) { tlsListener bool redirect bool accessURL string + requestURL string // Empty string means no redirect. expectRedirect string }{ @@ -558,6 +556,14 @@ func TestServer(t *testing.T) { accessURL: "https://example.com", expectRedirect: "", }, + { + name: "NoRedirectWithWildcard", + tlsListener: true, + accessURL: "https://example.com", + requestURL: "https://dev.example.com", + expectRedirect: "", + redirect: true, + }, { name: "NoTLSListener", httpListener: true, @@ -583,6 +589,10 @@ func TestServer(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() + if c.requestURL == "" { + c.requestURL = c.accessURL + } + httpListenAddr := "" if c.httpListener { httpListenAddr = ":0" @@ -601,6 +611,7 @@ func TestServer(t *testing.T) { "--tls-address", ":0", "--tls-cert-file", certPath, "--tls-key-file", keyPath, + "--wildcard-access-url", "*.example.com", ) } if c.accessURL != "" { @@ -661,7 +672,7 @@ func TestServer(t *testing.T) { // Verify TLS if c.tlsListener { - accessURLParsed, err := url.Parse(c.accessURL) + accessURLParsed, err := url.Parse(c.requestURL) require.NoError(t, err) client := codersdk.New(accessURLParsed) client.HTTPClient = &http.Client{ @@ -679,8 +690,9 @@ func TestServer(t *testing.T) { } defer client.HTTPClient.CloseIdleConnections() _, err = client.HasFirstUser(ctx) - require.NoError(t, err) - + if err != nil { + require.ErrorContains(t, err, "Invalid application URL") + } cancelFunc() require.NoError(t, <-errC) }