10000 feat(coderd): insert provisioner daemons by johnstcn · Pull Request #11207 · coder/coder · GitHub
[go: up one dir, main page]

Skip to content

feat(coderd): insert provisioner daemons #11207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add heartbeat
  • Loading branch information
johnstcn committed Dec 15, 2023
commit 6e7856afb228957268a44fe1c8b202502470b352
4 changes: 2 additions & 2 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -1190,9 +1190,9 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string
api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name))
logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name))
srv, err := provisionerdserver.NewServer(
api.ctx,
api.ctx, // use the same ctx as the API
api.AccessURL,
uuid.New(),
daemon.ID,
logger,
daemon.Provisioners,
provisionerdserver.Tags(daemon.Tags),
Expand Down
76 changes: 71 additions & 5 deletions coderd/provisionerdserver/provisionerdserver.go
10000
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,15 @@ import (
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
)

// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
// canceling and returning an empty job.
const DefaultAcquireJobLongPollDur = time.Second * 5
const (
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
// canceling and returning an empty job.
DefaultAcquireJobLongPollDur = time.Second * 5

// DefaultHeartbeatInterval is the interval at which the provisioner daemon
// will update its last seen at timestamp in the database.
DefaultHeartbeatInterval = time.Minute
)

type Options struct {
OIDCConfig httpmw.OAuth2Config
Expand All @@ -56,6 +62,15 @@ type Options struct {

// AcquireJobLongPollDur is used in tests
AcquireJobLongPollDur time.Duration

// HeartbeatInterval is the interval at which the provisioner daemon
// will update its last seen at timestamp in the database.
HeartbeatInterval time.Duration

// HeartbeatFn is the function that will be called at the interval
// specified by HeartbeatInterval.
// This is only used in tests.
HeartbeatFn func(context.Context) error
}

type server struct {
Expand Down Expand Up @@ -85,6 +100,9 @@ type server struct {
TimeNowFn func() time.Time

acquireJobLongPollDur time.Duration

HeartbeatInterval time.Duration
HeartbeatFn func(ctx context.Context) error
}

// We use the null byte (0x00) in generating a canonical map key for tags, so
Expand Down Expand Up @@ -161,7 +179,11 @@ func NewServer(
if options.AcquireJobLongPollDur == 0 {
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
}
return &server{
if options.HeartbeatInterval == 0 {
options.HeartbeatInterval = DefaultHeartbeatInterval
}

s := &server{
lifecycleCtx: lifecycleCtx,
AccessURL: accessURL,
ID: id,
Expand All @@ -182,7 +204,13 @@ func NewServer(
OIDCConfig: options.OIDCConfig,
TimeNowFn: options.TimeNowFn,
acquireJobLongPollDur: options.AcquireJobLongPollDur,
}, nil
HeartbeatInterval: options.HeartbeatInterval,
HeartbeatFn: options.HeartbeatFn,
}

go s.heartbeat()

return s, nil
}

// timeNow should be used when trying to get the current time for math
Expand All @@ -194,6 +222,44 @@ func (s *server) timeNow() time.Time {
return dbtime.Now()
}

// heartbeat runs heartbeatOnce at the interval specified by HeartbeatInterval
// until the lifecycle context is canceled.
func (s *server) heartbeat() {
tick := time.NewTicker(time.Nanosecond)
defer tick.Stop()
for {
select {
case <-s.lifecycleCtx.Done():
return
case <-tick.C:
hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.HeartbeatInterval)
if err := s.heartbeatOnce(hbCtx); err != nil {
s.Logger.Error(hbCtx, "heartbeat failed", slog.Error(err))
}
hbCancel()
tick.Reset(s.HeartbeatInterval)
}
}
}

// heartbeatOnce updates the last seen at timestamp in the database.
// If HeartbeatFn is set, it will be called instead.
func (s *server) heartbeatOnce(ctx context.Context) error {
if s.HeartbeatFn != nil {
return s.HeartbeatFn(ctx)
}

if s.lifecycleCtx.Err() != nil {
return nil
}

//nolint:gocritic // Provisionerd has specific authz rules.
return s.Database.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsProvisionerd(ctx), database.UpdateProvisionerDaemonLastSeenAtParams{
ID: s.ID,
LastSeenAt: sql.NullTime{Time: s.timeNow(), Valid: true},
})
}

// AcquireJob queries the database to lock a job.
//
// Deprecated: This method is only available for back-level provisioner daemons.
Expand Down
42 changes: 38 additions & 4 deletions coderd/provisionerdserver/provisionerdserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,29 @@ func TestAcquireJobWithCancel_Cancel(t *testing.T) {
require.Equal(t, "", job.JobId)
}

func TestHeartbeat(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
heartbeatChan := make(chan struct{})
heartbeatFn := func(context.Context) error {
heartbeatChan <- struct{}{}
return nil
}
//nolint:dogsled // this is a test
_, _, _ = setup(t, false, &overrides{
ctx: ctx,
heartbeatFn: heartbeatFn,
heartbeatInterval: testutil.IntervalFast,
})

<-heartbeatChan
cancel()
close(heartbeatChan)
<-time.After(testutil.IntervalFast)
}

func TestAcquireJob(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1686,19 +1709,20 @@ func TestInsertWorkspaceResource(t *testing.T) {
}

type overrides struct {
ctx context.Context
deploymentValues *codersdk.DeploymentValues
externalAuthConfigs []*externalauth.Config
id *uuid.UUID
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
timeNowFn func() time.Time
acquireJobLongPollDuration time.Duration
heartbeatFn func(ctx context.Context) error
heartbeatInterval time.Duration
}

func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
db := dbmem.New()
ps := pubsub.NewInMemory()
Expand All @@ -1710,6 +1734,14 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
var timeNowFn func() time.Time
pollDur := time.Duration(0)
if ov != nil {
if ov.ctx == nil {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
ov.ctx = ctx
}
if ov.heartbeatInterval == 0 {
ov.heartbeatInterval = testutil.IntervalMedium
}
if ov.deploymentValues != nil {
deploymentValues = ov.deploymentValues
}
Expand Down Expand Up @@ -1744,15 +1776,15 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
}

srv, err := provisionerdserver.NewServer(
ctx,
ov.ctx,
&url.URL{},
srvID,
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
[]database.ProvisionerType{database.ProvisionerTypeEcho},
provisionerdserver.Tags{},
db,
ps,
provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps),
provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps),
telemetry.NewNoop(),
trace.NewNoopTracerProvider().Tracer("noop"),
&atomic.Pointer[proto.QuotaCommitter]{},
Expand All @@ -1765,6 +1797,8 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
TimeNowFn: timeNowFn,
OIDCConfig: &oauth2.Config{},
AcquireJobLongPollDur: pollDur,
HeartbeatInterval: ov.heartbeatInterval,
HeartbeatFn: ov.heartbeatFn,
},
)
require.NoError(t, err)
Expand Down
9 changes: 6 additions & 3 deletions enterprise/coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}

// Create the daemon in the database.
_, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
Name: name,
Provisioners: provisioners,
Tags: tags,
Expand Down Expand Up @@ -295,11 +295,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}
mux := drpcmux.New()
logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name))
srvCtx, srvCancel := context.WithCancel(ctx)
defer srvCancel()
logger.Info(ctx, "starting external provisioner daemon")
srv, err := provisionerdserver.NewServer(
api.ctx,
srvCtx,
api.AccessURL,
id,
daemon.ID,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

review: another call-out: I am now setting the ID of the provisionerd server to be the same as the provisioner daemon that was inserted into the database, instead of a random UUID as it was before.

I believe this should be OK, but deferring to better judgement here.

logger,
provisioners,
tags,
Expand Down Expand Up @@ -339,6 +341,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
},
})
err = server.Serve(ctx, session)
srvCancel()
logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err))
if err != nil && !xerrors.Is(err, io.EOF) {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
Expand Down
0