diff --git a/coderd/autobuild/executor/lifecycle_executor.go b/coderd/autobuild/executor/lifecycle_executor.go index 161b85cf0f53f..1220a36419c0a 100644 --- a/coderd/autobuild/executor/lifecycle_executor.go +++ b/coderd/autobuild/executor/lifecycle_executor.go @@ -209,11 +209,17 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa } provisionerJobID := uuid.New() now := database.Now() + + systemUser, err := store.GetUserByID(ctx, database.SystemUserID) + if err != nil { + return xerrors.Errorf("get system user: %w", err) + } + newProvisionerJob, err := store.InsertProvisionerJob(ctx, database.InsertProvisionerJobParams{ ID: provisionerJobID, CreatedAt: now, UpdatedAt: now, - InitiatorID: workspace.OwnerID, + InitiatorID: systemUser.ID, OrganizationID: template.OrganizationID, Provisioner: template.Provisioner, Type: database.ProvisionerJobTypeWorkspaceBuild, @@ -233,7 +239,7 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa BuildNumber: priorBuildNumber + 1, Name: namesgenerator.GetRandomName(1), ProvisionerState: priorHistory.ProvisionerState, - InitiatorID: workspace.OwnerID, + InitiatorID: systemUser.ID, Transition: trans, JobID: newProvisionerJob.ID, }) diff --git a/coderd/autobuild/executor/lifecycle_executor_test.go b/coderd/autobuild/executor/lifecycle_executor_test.go index 1680fe0368e12..24fa896188cb7 100644 --- a/coderd/autobuild/executor/lifecycle_executor_test.go +++ b/coderd/autobuild/executor/lifecycle_executor_test.go @@ -2,7 +2,6 @@ package executor_test import ( "context" - "os" "testing" "time" @@ -483,7 +482,7 @@ func TestExecutorWorkspaceAutostopNoWaitChangedMyMind(t *testing.T) { } func TestExecutorAutostartMultipleOK(t *testing.T) { - if os.Getenv("DB") == "" { + if !coderdtest.UseSQL() { t.Skip(`This test only really works when using a "real" database, similar to a HA setup`) } diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 2eb25ccd2091a..1178ed39ce459 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -79,6 +79,11 @@ func New(t *testing.T, options *Options) *codersdk.Client { return client } +// UseSQL returns true if a Postgres server is running and can be used for tests. +func UseSQL() bool { + return os.Getenv("DB") != "" +} + // NewWithAPI constructs a codersdk client connected to the returned in-memory API instance. func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, *coderd.API) { if options == nil { @@ -105,7 +110,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, *coderd.API) // This can be hotswapped for a live database instance. db := databasefake.New() pubsub := database.NewPubsubInMemory() - if os.Getenv("DB") != "" { + if UseSQL() { connectionURL, closePg, err := postgres.Open() require.NoError(t, err) t.Cleanup(closePg) diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 228d89c2f4444..4e115ee14a4db 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -16,11 +16,21 @@ import ( // New returns an in-memory fake of the database. func New() database.Store { + systemUser := database.User{ + ID: database.SystemUserID, + Email: "system@coder.com", + Username: "system", + HashedPassword: make([]byte, 0), + CreatedAt: database.Now(), + UpdatedAt: database.Now(), + Status: database.UserStatusActive, + RBACRoles: make([]string, 0), + } return &fakeQuerier{ apiKeys: make([]database.APIKey, 0), organizationMembers: make([]database.OrganizationMember, 0), organizations: make([]database.Organization, 0), - users: make([]database.User, 0), + users: []database.User{systemUser}, auditLogs: make([]database.AuditLog, 0), files: make([]database.File, 0), @@ -179,11 +189,18 @@ func (q *fakeQuerier) GetUserByID(_ context.Context, id uuid.UUID) (database.Use return database.User{}, sql.ErrNoRows } -func (q *fakeQuerier) GetUserCount(_ context.Context) (int64, error) { +func (q *fakeQuerier) GetActualUserCount(_ context.Context) (int64, error) { q.mutex.RLock() defer q.mutex.RUnlock() - return int64(len(q.users)), nil + var count int64 + for _, user := range q.users { + if user.ID != database.SystemUserID { + count++ + } + } + + return count, nil } func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams) ([]database.User, error) { @@ -233,6 +250,16 @@ func (q *fakeQuerier) GetUsers(_ context.Context, params database.GetUsersParams users = tmp } + if !params.IncludeSystemUser { + tmp := make([]database.User, 0, len(users)) + for i, user := range users { + if user.ID != database.SystemUserID { + tmp = append(tmp, users[i]) + } + } + users = tmp + } + if len(params.Status) == 0 { params.Status = []database.UserStatus{database.UserStatusActive} } diff --git a/coderd/database/db.go b/coderd/database/db.go index 5236b18d65c82..bf6029f9a2558 100644 --- a/coderd/database/db.go +++ b/coderd/database/db.go @@ -13,9 +13,12 @@ import ( "database/sql" "errors" + "github.com/google/uuid" "golang.org/x/xerrors" ) +var SystemUserID uuid.UUID = uuid.MustParse("c0de2b07-0000-4000-A000-000000000000") + // Store contains all queryable database functions. // It extends the generated interface to add transaction support. type Store interface { diff --git a/coderd/database/migrations/000024_add_system_user.down.sql b/coderd/database/migrations/000024_add_system_user.down.sql new file mode 100644 index 0000000000000..974759c4bf971 --- /dev/null +++ b/coderd/database/migrations/000024_add_system_user.down.sql @@ -0,0 +1,6 @@ +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; + +DELETE FROM + users +WHERE + id = 'c0de2b07-0000-4000-A000-000000000000'; diff --git a/coderd/database/migrations/000024_add_system_user.up.sql b/coderd/database/migrations/000024_add_system_user.up.sql new file mode 100644 index 0000000000000..086a4316ec848 --- /dev/null +++ b/coderd/database/migrations/000024_add_system_user.up.sql @@ -0,0 +1,12 @@ +INSERT INTO + users ( + id, + email, + username, + hashed_password, + created_at, + updated_at, + rbac_roles + ) +VALUES + ('c0de2b07-0000-4000-A000-000000000000', 'system@coder.com', 'system', '', NOW(), NOW(), '{}'); diff --git a/coderd/database/querier.go b/coderd/database/querier.go index ffac6902a13e1..20fae32eb64fd 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -22,6 +22,8 @@ type querier interface { DeleteGitSSHKey(ctx context.Context, userID uuid.UUID) error DeleteParameterValueByID(ctx context.Context, id uuid.UUID) error GetAPIKeyByID(ctx context.Context, id string) (APIKey, error) + // Actual user count refers to the count of all users except the system user + GetActualUserCount(ctx context.Context) (int64, error) // GetAuditLogsBefore retrieves `limit` number of audit logs before the provided // ID. GetAuditLogsBefore(ctx context.Context, arg GetAuditLogsBeforeParams) ([]AuditLog, error) @@ -56,7 +58,6 @@ type querier interface { GetTemplatesWithFilter(ctx context.Context, arg GetTemplatesWithFilterParams) ([]Template, error) GetUserByEmailOrUsername(ctx context.Context, arg GetUserByEmailOrUsernameParams) (User, error) GetUserByID(ctx context.Context, id uuid.UUID) (User, error) - GetUserCount(ctx context.Context) (int64, error) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 8667a60ecfa74..035c04686bc7b 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2136,6 +2136,23 @@ func (q *sqlQuerier) UpdateTemplateVersionDescriptionByJobID(ctx context.Context return err } +const getActualUserCount = `-- name: GetActualUserCount :one +SELECT + COUNT(*) +FROM + users +WHERE + id != 'c0de2b07-0000-4000-A000-000000000000' +` + +// Actual user count refers to the count of all users except the system user +func (q *sqlQuerier) GetActualUserCount(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, getActualUserCount) + var count int64 + err := row.Scan(&count) + return count, err +} + const getAuthorizationUserRoles = `-- name: GetAuthorizationUserRoles :one SELECT -- username is returned just to help for logging purposes @@ -2237,20 +2254,6 @@ func (q *sqlQuerier) GetUserByID(ctx context.Context, id uuid.UUID) (User, error return i, err } -const getUserCount = `-- name: GetUserCount :one -SELECT - COUNT(*) -FROM - users -` - -func (q *sqlQuerier) GetUserCount(ctx context.Context) (int64, error) { - row := q.db.QueryRowContext(ctx, getUserCount) - var count int64 - err := row.Scan(&count) - return count, err -} - const getUsers = `-- name: GetUsers :many SELECT id, email, username, hashed_password, created_at, updated_at, status, rbac_roles @@ -2285,12 +2288,19 @@ WHERE ) ELSE true END + -- Filter out system user + AND CASE + WHEN $3 :: boolean THEN true + ELSE ( + id != 'c0de2b07-0000-4000-A000-000000000000' + ) + END -- Filter by status AND CASE -- @status needs to be a text because it can be empty, If it was -- user_status enum, it would not. - WHEN cardinality($3 :: user_status[]) > 0 THEN ( - status = ANY($3 :: user_status[]) + WHEN cardinality($4 :: user_status[]) > 0 THEN ( + status = ANY($4 :: user_status[]) ) ELSE -- Only show active by default @@ -2300,24 +2310,26 @@ WHERE ORDER BY -- Deterministic and consistent ordering of all users, even if they share -- a timestamp. This is to ensure consistent pagination. - (created_at, id) ASC OFFSET $4 + (created_at, id) ASC OFFSET $5 LIMIT -- A null limit means "no limit", so -1 means return all - NULLIF($5 :: int, -1) + NULLIF($6 :: int, -1) ` type GetUsersParams struct { - AfterID uuid.UUID `db:"after_id" json:"after_id"` - Search string `db:"search" json:"search"` - Status []UserStatus `db:"status" json:"status"` - OffsetOpt int32 `db:"offset_opt" json:"offset_opt"` - LimitOpt int32 `db:"limit_opt" json:"limit_opt"` + AfterID uuid.UUID `db:"after_id" json:"after_id"` + Search string `db:"search" json:"search"` + IncludeSystemUser bool `db:"include_system_user" json:"include_system_user"` + Status []UserStatus `db:"status" json:"status"` + OffsetOpt int32 `db:"offset_opt" json:"offset_opt"` + LimitOpt int32 `db:"limit_opt" json:"limit_opt"` } func (q *sqlQuerier) GetUsers(ctx context.Context, arg GetUsersParams) ([]User, error) { rows, err := q.db.QueryContext(ctx, getUsers, arg.AfterID, arg.Search, + arg.IncludeSystemUser, pq.Array(arg.Status), arg.OffsetOpt, arg.LimitOpt, diff --git a/coderd/database/queries/users.sql b/coderd/database/queries/users.sql index be4a42a1538b7..445950057be2c 100644 --- a/coderd/database/queries/users.sql +++ b/coderd/database/queries/users.sql @@ -22,11 +22,14 @@ WHERE LIMIT 1; --- name: GetUserCount :one +-- name: GetActualUserCount :one +-- Actual user count refers to the count of all users except the system user SELECT COUNT(*) FROM - users; + users +WHERE + id != 'c0de2b07-0000-4000-A000-000000000000'; -- name: InsertUser :one INSERT INTO @@ -104,6 +107,13 @@ WHERE ) ELSE true END + -- Filter out system user + AND CASE + WHEN @include_system_user :: boolean THEN true + ELSE ( + id != 'c0de2b07-0000-4000-A000-000000000000' + ) + END -- Filter by status AND CASE -- @status needs to be a text because it can be empty, If it was diff --git a/coderd/users.go b/coderd/users.go index 781cc213d1797..2cb86b75f7e57 100644 --- a/coderd/users.go +++ b/coderd/users.go @@ -27,7 +27,7 @@ import ( // Returns whether the initial user has been created or not. func (api *API) firstUser(rw http.ResponseWriter, r *http.Request) { - userCount, err := api.Database.GetUserCount(r.Context()) + userCount, err := api.Database.GetActualUserCount(r.Context()) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: "Internal error fetching user count.", @@ -56,7 +56,7 @@ func (api *API) postFirstUser(rw http.ResponseWriter, r *http.Request) { } // This should only function for the first user. - userCount, err := api.Database.GetUserCount(r.Context()) + userCount, err := api.Database.GetActualUserCount(r.Context()) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ Message: "Internal error fetching user count.", diff --git a/coderd/users_test.go b/coderd/users_test.go index c1ec00fd97378..6341506a45285 100644 --- a/coderd/users_test.go +++ b/coderd/users_test.go @@ -21,6 +21,34 @@ import ( "github.com/coder/coder/codersdk" ) +func TestSystemUser(t *testing.T) { + t.Parallel() + t.Run("SQLMatchesFake", func(t *testing.T) { + if !coderdtest.UseSQL() { + t.Skip("This test asserts that the system user is equivalent in SQL and the fake database.") + } + + t.Parallel() + + _, opts := coderdtest.NewWithAPI(t, nil) + fake := databasefake.New() + + fakeUser, _ := fake.GetUserByID(context.Background(), database.SystemUserID) + sqlUser, _ := opts.Database.GetUserByID(context.Background(), database.SystemUserID) + + // These fields are different as they use the actual timestamps at creation + fakeUser.CreatedAt, fakeUser.UpdatedAt = time.Time{}, time.Time{} + sqlUser.CreatedAt, sqlUser.UpdatedAt = time.Time{}, time.Time{} + + require.Equal(t, fakeUser, sqlUser) + }) + t.Run("ValidUUID", func(t *testing.T) { + t.Parallel() + require.Equal(t, uuid.Version(4), database.SystemUserID.Version()) + require.Equal(t, uuid.RFC4122, database.SystemUserID.Variant()) + }) +} + func TestFirstUser(t *testing.T) { t.Parallel() t.Run("BadRequest", func(t *testing.T) {