8000 feat: add session token injection to provisioner by sreya · Pull Request #7461 · coder/coder · GitHub
[go: up one dir, main page]

Skip to content

feat: add session token injection to provisioner #7461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update provisionerd test
  • Loading branch information
sreya committed May 16, 2023
commit 6ad169d80848b4b25d590992f6afcc364a304200
61 changes: 35 additions & 26 deletions coderd/provisionerdserver/provisionerdserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,18 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
}
}

sessionToken, err := server.regenerateSessionToken(ctx, owner, workspace)
if err != nil {
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
var sessionToken string
switch workspaceBuild.Transition {
case database.WorkspaceTransitionStart:
sessionToken, err = server.regenerateSessionToken(ctx, owner, workspace)
if err != nil {
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
}
case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete:
err = server.deleteSessionToken(ctx, workspace)
if err != nil {
return nil, failJob(fmt.Sprintf("delete session token: %s", err))
}
}

// Compute parameters for the workspace to consume.
Expand Down Expand Up @@ -1434,35 +1443,35 @@ func (server *Server) regenerateSessionToken(ctx context.Context, user database.
return "", xerrors.Errorf("generate API key: %w", err)
}

err = server.Database.InTx(
func(tx database.Store) error {
key, err := tx.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
UserID: workspace.OwnerID,
TokenName: workspaceSessionTokenName(workspace),
})
if err == nil {
err = tx.DeleteAPIKeyByID(ctx, key.ID)
if err != nil {
return xerrors.Errorf("delete api key: %w", err)
}
}
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get api key by name: %w", err)
}

_, err = tx.InsertAPIKey(ctx, newkey)
if err != nil {
return xerrors.Errorf("insert API key: %w", err)
}
err = server.deleteSessionToken(ctx, workspace)
if err != nil {
return "", xerrors.Errorf("delete session token: %w", err)
}

return nil
}, nil)
_, err = server.Database.InsertAPIKey(ctx, newkey)
if err != nil {
return "", xerrors.Errorf("regenerate API key: %w", err)
return "", xerrors.Errorf("insert API key: %w", err)
}

return secret, nil
}

func (server *Server) deleteSessionToken(ctx context.Context, workspace database.Workspace) error {
key, err := server.Database.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
UserID: workspace.OwnerID,
TokenName: workspaceSessionTokenName(workspace),
})
if err == nil {
err = server.Database.DeleteAPIKeyByID(ctx, key.ID)
}

if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf("get api key by name: %w", err)
}

return nil
}

// obtainOIDCAccessToken returns a valid OpenID Connect access token
// for the user if it's able to obtain one, otherwise it returns an empty string.
func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig httpmw.OAuth2Config, userID uuid.UUID) (string, error) {
Expand Down
59 changes: 54 additions & 5 deletions coderd/provisionerdserver/provisionerdserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,16 @@ func TestAcquireJob(t *testing.T) {
})),
})

published := make(chan struct{})
closeSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
close(published)
startPublished := make(chan struct{})
var closed bool
closeStartSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
if !closed {
close(startPublished)
closed = true
}
})
require.NoError(t, err)
defer closeSubscribe()
defer closeStartSubscribe()

var job *proto.AcquiredJob

Expand All @@ -218,7 +222,7 @@ func TestAcquireJob(t *testing.T) {
}
}

<-published
<-startPublished

got, err := json.Marshal(job.Type)
require.NoError(t, err)
Expand Down Expand Up @@ -271,7 +275,52 @@ func TestAcquireJob(t *testing.T) {
require.NoError(t, err)

require.JSONEq(t, string(want), string(got))

// Assert that we delete the session token whenever
// a stop is issued.
stopbuild := dbgen.WorkspaceBuild(t, srv.Database, database.WorkspaceBuild{
WorkspaceID: workspace.ID,
BuildNumber: 2,
JobID: uuid.New(),
TemplateVersionID: version.ID,
Transition: database.WorkspaceTransitionStop,
Reason: database.BuildReasonInitiator,
})
_ = dbgen.ProvisionerJob(t, srv.Database, database.ProvisionerJob{
ID: stopbuild.ID,
InitiatorID: user.ID,
Provisioner: database.ProvisionerTypeEcho,
StorageMethod: database.ProvisionerStorageMethodFile,
FileID: file.ID,
Type: database.ProvisionerJobTypeWorkspaceBuild,
Input: must(json.Marshal(provisionerdserver.WorkspaceProvisionJob{
WorkspaceBuildID: stopbuild.ID,
})),
})

stopPublished := make(chan struct{})
closeStopSubscribe, err := srv.Pubsub.Subscribe(codersdk.WorkspaceNotifyChannel(workspace.ID), func(_ context.Context, _ []byte) {
close(stopPublished)
})
require.NoError(t, err)
defer closeStopSubscribe()

// Grab jobs until we find the workspace build job. There is also
// an import version job that we need to ignore.
job, err = srv.AcquireJob(ctx, nil)
require.NoError(t, err)
_, ok := job.Type.(*proto.AcquiredJob_WorkspaceBuild_)
require.True(t, ok, "acquired job not a workspace build?")

<-stopPublished

// Validate that a session token is deleted during a stop job.
sessionToken = job.Type.(*proto.AcquiredJob_WorkspaceBuild_).WorkspaceBuild.Metadata.CoderSessionToken
require.Empty(t, sessionToken)
_, err = srv.Database.GetAPIKeyByID(ctx, key.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
})

t.Run("TemplateVersionDryRun", func(t *testing.T) {
t.Parallel()
srv := setup(t, false)
Expand Down
0