From 38574ff22f3cf62a367a748a17ca2419772818fd Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 1 Oct 2024 16:55:10 +0000 Subject: [PATCH 01/13] feat: add wsproxy implementation for key fetching --- enterprise/wsproxy/keycache.go | 152 ++++++++ enterprise/wsproxy/keycache_test.go | 371 ++++++++++++++++++++ enterprise/wsproxy/wsproxysdk/wsproxysdk.go | 30 ++ 3 files changed, 553 insertions(+) create mode 100644 enterprise/wsproxy/keycache.go create mode 100644 enterprise/wsproxy/keycache_test.go diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go new file mode 100644 index 0000000000000..2571ce7b3cc8a --- /dev/null +++ b/enterprise/wsproxy/keycache.go @@ -0,0 +1,152 @@ +package wsproxy + +import ( + "context" + "sync" + "time" + + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" + "github.com/coder/quartz" +) + +type CryptoKeyCache struct { + client *wsproxysdk.Client + logger slog.Logger + Clock quartz.Clock + + keysMu sync.RWMutex + keys map[int32]wsproxysdk.CryptoKey + latest wsproxysdk.CryptoKey +} + +func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.Client, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { + cache := &CryptoKeyCache{ + client: client, + logger: log, + Clock: quartz.NewReal(), + } + + for _, opt := range opts { + opt(cache) + } + + m, latest, err := cache.fetch(ctx) + if err != nil { + return nil, xerrors.Errorf("initial fetch: %w", err) + } + cache.keys, cache.latest = m, latest + + go cache.refresh(ctx) + + return cache, nil +} + +func (k *CryptoKeyCache) Latest(ctx context.Context) (wsproxysdk.CryptoKey, error) { + k.keysMu.RLock() + latest := k.latest + k.keysMu.RUnlock() + + now := k.Clock.Now().UTC() + if latest.Active(now) { + return latest, nil + } + + k.keysMu.Lock() + defer k.keysMu.Unlock() + + if k.latest.Active(now) { + return k.latest, nil + } + + var err error + k.keys, k.latest, err = k.fetch(ctx) + if err != nil { + return wsproxysdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + } + + if !k.latest.Active(now) { + return wsproxysdk.CryptoKey{}, xerrors.Errorf("no active keys found") + } + + return k.latest, nil +} + +func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (wsproxysdk.CryptoKey, error) { + now := k.Clock.Now().UTC() + k.keysMu.RLock() + key, ok := k.keys[sequence] + k.keysMu.RUnlock() + if ok { + return validKey(key, now) + } + + k.keysMu.Lock() + defer k.keysMu.Unlock() + key, ok = k.keys[sequence] + if ok { + return validKey(key, now) + } + + var err error + k.keys, k.latest, err = k.fetch(ctx) + if err != nil { + return wsproxysdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + } + + key, ok = k.keys[sequence] + if !ok { + return wsproxysdk.CryptoKey{}, xerrors.Errorf("key %d not found", sequence) + } + + return validKey(key, now) +} + +func (k *CryptoKeyCache) refresh(ctx context.Context) { + k.Clock.TickerFunc(ctx, time.Minute*10, func() error { + kmap, latest, err := k.fetch(ctx) + if err != nil { + k.logger.Error(ctx, "failed to fetch crypto keys", slog.Error(err)) + return nil + } + + k.keysMu.Lock() + defer k.keysMu.Unlock() + k.keys = kmap + k.latest = latest + return nil + }) +} + +func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]wsproxysdk.CryptoKey, wsproxysdk.CryptoKey, error) { + keys, err := k.client.CryptoKeys(ctx) + if err != nil { + return nil, wsproxysdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err) + } + + kmap, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now().UTC()) + return kmap, latest, nil +} + +func toKeyMap(keys []wsproxysdk.CryptoKey, now time.Time) (map[int32]wsproxysdk.CryptoKey, wsproxysdk.CryptoKey) { + m := make(map[int32]wsproxysdk.CryptoKey) + var latest wsproxysdk.CryptoKey + for _, key := range keys { + m[key.Sequence] = key + if key.Sequence > latest.Sequence && key.Active(now) { + latest = key + } + } + return m, latest +} + +func validKey(key wsproxysdk.CryptoKey, now time.Time) (wsproxysdk.CryptoKey, error) { + if key.Invalid(now) { + return wsproxysdk.CryptoKey{}, xerrors.Errorf("key %d is invalid", key.Sequence) + } + + return key, nil +} diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go new file mode 100644 index 0000000000000..04b39bd866d91 --- /dev/null +++ b/enterprise/wsproxy/keycache_test.go @@ -0,0 +1,371 @@ +package wsproxy_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + "github.com/coder/coder/v2/enterprise/wsproxy" + "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestCryptoKeyCache(t *testing.T) { + t.Parallel() + + t.Run("Latest", func(t *testing.T) { + t.Parallel() + + t.Run("HitsCache", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now().UTC() + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 2, + StartsAt: now, + } + + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + { + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 1, + StartsAt: now, + }, + // Should be ignored since it hasn't breached its starts_at time yet. + { + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key3", + Sequence: 3, + StartsAt: now.Add(time.Second * 2), + }, + expected, + }) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + got, err := cache.Latest(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, fc.called) + }) + + t.Run("MissesCache", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{}) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: clock.Now().UTC(), + } + fc.keys = []wsproxysdk.CryptoKey{expected} + + got, err := cache.Latest(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + // 1 on startup + missing cache. + require.Equal(t, 2, fc.called) + + // Ensure the cache gets hit this time. + got, err = cache.Latest(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + // 1 on startup + missing cache. + require.Equal(t, 2, fc.called) + }) + + t.Run("IgnoresInvalid", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + now := clock.Now().UTC() + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 1, + StartsAt: clock.Now().UTC(), + } + + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + expected, + { + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 2, + StartsAt: now.Add(-time.Second), + DeletesAt: now, + }, + }) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + got, err := cache.Latest(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, fc.called) + }) + }) + + t.Run("Version", func(t *testing.T) { + t.Parallel() + + t.Run("HitsCache", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now().UTC() + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + } + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + expected, + { + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 13, + StartsAt: now, + }, + }) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + got, err := cache.Version(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, fc.called) + }) + + t.Run("MissesCache", func(t *testing.T) { + t.Parallel() + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{}) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: clock.Now().UTC(), + } + fc.keys = []wsproxysdk.CryptoKey{expected} + + got, err := cache.Version(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 2, fc.called) + + // Ensure the cache gets hit this time. + got, err = cache.Version(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 2, fc.called) + }) + + t.Run("AllowsBeforeStartsAt", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now().UTC() + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now.Add(-time.Second), + } + + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + expected, + }) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + got, err := cache.Version(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, fc.called) + }) + + t.Run("NoInvalid", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now().UTC() + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now.Add(-time.Second), + DeletesAt: now, + } + + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + expected, + }) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + _, err = cache.Version(ctx, expected.Sequence) + require.Error(t, err) + require.Equal(t, 1, fc.called) + }) + }) + + t.Run("CacheRefreshes", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + trap := clock.Trap().TickerFunc() + + now := clock.Now().UTC() + expected := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + DeletesAt: now.Add(time.Minute * 10), + } + fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + expected, + }) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + got, err := cache.Latest(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, fc.called) + + wait := trap.MustWait(ctx) + + newKey := wsproxysdk.CryptoKey{ + Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 13, + StartsAt: now, + } + fc.keys = []wsproxysdk.CryptoKey{newKey} + + wait.Release() + + // The ticker should fire and cause a request to coderd. + _, advance := clock.AdvanceNext() + advance.MustWait(ctx) + require.Equal(t, 2, fc.called) + + // Assert hits cache. + got, err = cache.Latest(ctx) + require.NoError(t, err) + require.Equal(t, newKey, got) + require.Equal(t, 2, fc.called) + + // Assert we do not have the old key. + _, err = cache.Version(ctx, expected.Sequence) + require.Error(t, err) + }) +} + +type fakeCoderd struct { + server *httptest.Server + keys []wsproxysdk.CryptoKey + called int + url *url.URL +} + +func newFakeCoderd(t *testing.T, keys []wsproxysdk.CryptoKey) *fakeCoderd { + t.Helper() + + c := &fakeCoderd{ + keys: keys, + } + + mux := http.NewServeMux() + mux.HandleFunc("/api/v2/workspaceproxies/me/crypto-keys", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(wsproxysdk.CryptoKeysResponse{ + CryptoKeys: c.keys, + }) + require.NoError(t, err) + c.called++ + }) + + c.server = httptest.NewServer(mux) + t.Cleanup(c.server.Close) + + var err error + c.url, err = url.Parse(c.server.URL) + require.NoError(t, err) + + return c +} + +func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) { + return func(cache *wsproxy.CryptoKeyCache) { + cache.Clock = clock + } +} diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 77d36561c6de8..0da47068fa26e 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -204,7 +204,37 @@ type RegisterWorkspaceProxyRequest struct { Version string `json:"version"` } +type CryptoKeyFeature string + +const ( + CryptoKeyFeatureWorkspaceApp CryptoKeyFeature = "workspace_apps" + CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" + CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" +) + +type CryptoKey struct { + Feature CryptoKeyFeature `json:"feature"` + Secret string `json:"secret"` + DeletesAt time.Time `json:"deletes_at"` + Sequence int32 `json:"sequence"` + StartsAt time.Time `json:"starts_at"` +} + +func (c CryptoKey) Active(now time.Time) bool { + now = now.UTC() + isAfterStartsAt := !c.StartsAt.IsZero() && !now.Before(c.StartsAt) + return isAfterStartsAt && !c.Invalid(now) +} + +func (c CryptoKey) Invalid(now time.Time) bool { + now = now.UTC() + noSecret := c.Secret == "" + afterDelete := !c.DeletesAt.IsZero() && !now.Before(c.DeletesAt.UTC()) + return noSecret || afterDelete +} + type RegisterWorkspaceProxyResponse struct { + Keys []CryptoKey `json:"keys"` AppSecurityKey string `json:"app_security_key"` DERPMeshKey string `json:"derp_mesh_key"` DERPRegionID int32 `json:"derp_region_id"` From ea5ec77b18877c172faf7ca1c7f48b72fa3fd76d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 1 Oct 2024 17:02:18 +0000 Subject: [PATCH 02/13] Refactor CryptoKey usage to use codersdk package This change streamlines the CryptoKey handling by utilizing the codersdk package, thereby reducing code duplication and potentially simplifying maintenance efforts in the future. --- enterprise/wsproxy/keycache.go | 41 ++++++------ enterprise/wsproxy/keycache_test.go | 71 +++++++++++---------- enterprise/wsproxy/wsproxysdk/wsproxysdk.go | 30 --------- 3 files changed, 57 insertions(+), 85 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index 2571ce7b3cc8a..cbb79eadb1c7c 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -9,6 +9,7 @@ import ( "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/quartz" ) @@ -19,8 +20,8 @@ type CryptoKeyCache struct { Clock quartz.Clock keysMu sync.RWMutex - keys map[int32]wsproxysdk.CryptoKey - latest wsproxysdk.CryptoKey + keys map[int32]codersdk.CryptoKey + latest codersdk.CryptoKey } func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.Client, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { @@ -45,37 +46,37 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk. return cache, nil } -func (k *CryptoKeyCache) Latest(ctx context.Context) (wsproxysdk.CryptoKey, error) { +func (k *CryptoKeyCache) Latest(ctx context.Context) (codersdk.CryptoKey, error) { k.keysMu.RLock() latest := k.latest k.keysMu.RUnlock() now := k.Clock.Now().UTC() - if latest.Active(now) { + if latest.CanSign(now) { return latest, nil } k.keysMu.Lock() defer k.keysMu.Unlock() - if k.latest.Active(now) { + if k.latest.CanSign(now) { return k.latest, nil } var err error k.keys, k.latest, err = k.fetch(ctx) if err != nil { - return wsproxysdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) } - if !k.latest.Active(now) { - return wsproxysdk.CryptoKey{}, xerrors.Errorf("no active keys found") + if !k.latest.CanSign(now) { + return codersdk.CryptoKey{}, xerrors.Errorf("no active keys found") } return k.latest, nil } -func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (wsproxysdk.CryptoKey, error) { +func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { now := k.Clock.Now().UTC() k.keysMu.RLock() key, ok := k.keys[sequence] @@ -94,12 +95,12 @@ func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (wsproxysd var err error k.keys, k.latest, err = k.fetch(ctx) if err != nil { - return wsproxysdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) } key, ok = k.keys[sequence] if !ok { - return wsproxysdk.CryptoKey{}, xerrors.Errorf("key %d not found", sequence) + return codersdk.CryptoKey{}, xerrors.Errorf("key %d not found", sequence) } return validKey(key, now) @@ -121,31 +122,31 @@ func (k *CryptoKeyCache) refresh(ctx context.Context) { }) } -func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]wsproxysdk.CryptoKey, wsproxysdk.CryptoKey, error) { +func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { keys, err := k.client.CryptoKeys(ctx) if err != nil { - return nil, wsproxysdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err) + return nil, codersdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err) } kmap, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now().UTC()) return kmap, latest, nil } -func toKeyMap(keys []wsproxysdk.CryptoKey, now time.Time) (map[int32]wsproxysdk.CryptoKey, wsproxysdk.CryptoKey) { - m := make(map[int32]wsproxysdk.CryptoKey) - var latest wsproxysdk.CryptoKey +func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey) { + m := make(map[int32]codersdk.CryptoKey) + var latest codersdk.CryptoKey for _, key := range keys { m[key.Sequence] = key - if key.Sequence > latest.Sequence && key.Active(now) { + if key.Sequence > latest.Sequence && key.CanSign(now) { latest = key } } return m, latest } -func validKey(key wsproxysdk.CryptoKey, now time.Time) (wsproxysdk.CryptoKey, error) { - if key.Invalid(now) { - return wsproxysdk.CryptoKey{}, xerrors.Errorf("key %d is invalid", key.Sequence) +func validKey(key codersdk.CryptoKey, now time.Time) (codersdk.CryptoKey, error) { + if !key.CanSign(now) { + return codersdk.CryptoKey{}, xerrors.Errorf("key %d is invalid", key.Sequence) } return key, nil diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index 04b39bd866d91..de66bc74d3cca 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -12,6 +12,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/coder/v2/testutil" @@ -33,23 +34,23 @@ func TestCryptoKeyCache(t *testing.T) { ) now := clock.Now().UTC() - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key2", Sequence: 2, StartsAt: now, } - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + fc := newFakeCoderd(t, []codersdk.CryptoKey{ { - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 1, StartsAt: now, }, // Should be ignored since it hasn't breached its starts_at time yet. { - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key3", Sequence: 3, StartsAt: now.Add(time.Second * 2), @@ -74,18 +75,18 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{}) + fc := newFakeCoderd(t, []codersdk.CryptoKey{}) cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 12, StartsAt: clock.Now().UTC(), } - fc.keys = []wsproxysdk.CryptoKey{expected} + fc.keys = []codersdk.CryptoKey{expected} got, err := cache.Latest(ctx) require.NoError(t, err) @@ -110,17 +111,17 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) now := clock.Now().UTC() - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 1, StartsAt: clock.Now().UTC(), } - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + fc := newFakeCoderd(t, []codersdk.CryptoKey{ expected, { - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key2", Sequence: 2, StartsAt: now.Add(-time.Second), @@ -151,16 +152,16 @@ func TestCryptoKeyCache(t *testing.T) { ) now := clock.Now().UTC() - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 12, StartsAt: now, } - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + fc := newFakeCoderd(t, []codersdk.CryptoKey{ expected, { - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key2", Sequence: 13, StartsAt: now, @@ -184,18 +185,18 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{}) + fc := newFakeCoderd(t, []codersdk.CryptoKey{}) cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 12, StartsAt: clock.Now().UTC(), } - fc.keys = []wsproxysdk.CryptoKey{expected} + fc.keys = []codersdk.CryptoKey{expected} got, err := cache.Version(ctx, expected.Sequence) require.NoError(t, err) @@ -219,14 +220,14 @@ func TestCryptoKeyCache(t *testing.T) { ) now := clock.Now().UTC() - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 12, StartsAt: now.Add(-time.Second), } - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + fc := newFakeCoderd(t, []codersdk.CryptoKey{ expected, }) @@ -249,15 +250,15 @@ func TestCryptoKeyCache(t *testing.T) { ) now := clock.Now().UTC() - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 12, StartsAt: now.Add(-time.Second), DeletesAt: now, } - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + fc := newFakeCoderd(t, []codersdk.CryptoKey{ expected, }) @@ -282,14 +283,14 @@ func TestCryptoKeyCache(t *testing.T) { trap := clock.Trap().TickerFunc() now := clock.Now().UTC() - expected := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key1", Sequence: 12, StartsAt: now, DeletesAt: now.Add(time.Minute * 10), } - fc := newFakeCoderd(t, []wsproxysdk.CryptoKey{ + fc := newFakeCoderd(t, []codersdk.CryptoKey{ expected, }) @@ -303,13 +304,13 @@ func TestCryptoKeyCache(t *testing.T) { wait := trap.MustWait(ctx) - newKey := wsproxysdk.CryptoKey{ - Feature: wsproxysdk.CryptoKeyFeatureWorkspaceApp, + newKey := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key2", Sequence: 13, StartsAt: now, } - fc.keys = []wsproxysdk.CryptoKey{newKey} + fc.keys = []codersdk.CryptoKey{newKey} wait.Release() @@ -332,12 +333,12 @@ func TestCryptoKeyCache(t *testing.T) { type fakeCoderd struct { server *httptest.Server - keys []wsproxysdk.CryptoKey + keys []codersdk.CryptoKey called int url *url.URL } -func newFakeCoderd(t *testing.T, keys []wsproxysdk.CryptoKey) *fakeCoderd { +func newFakeCoderd(t *testing.T, keys []codersdk.CryptoKey) *fakeCoderd { t.Helper() c := &fakeCoderd{ diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 0da47068fa26e..77d36561c6de8 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -204,37 +204,7 @@ type RegisterWorkspaceProxyRequest struct { Version string `json:"version"` } -type CryptoKeyFeature string - -const ( - CryptoKeyFeatureWorkspaceApp CryptoKeyFeature = "workspace_apps" - CryptoKeyFeatureOIDCConvert CryptoKeyFeature = "oidc_convert" - CryptoKeyFeatureTailnetResume CryptoKeyFeature = "tailnet_resume" -) - -type CryptoKey struct { - Feature CryptoKeyFeature `json:"feature"` - Secret string `json:"secret"` - DeletesAt time.Time `json:"deletes_at"` - Sequence int32 `json:"sequence"` - StartsAt time.Time `json:"starts_at"` -} - -func (c CryptoKey) Active(now time.Time) bool { - now = now.UTC() - isAfterStartsAt := !c.StartsAt.IsZero() && !now.Before(c.StartsAt) - return isAfterStartsAt && !c.Invalid(now) -} - -func (c CryptoKey) Invalid(now time.Time) bool { - now = now.UTC() - noSecret := c.Secret == "" - afterDelete := !c.DeletesAt.IsZero() && !now.Before(c.DeletesAt.UTC()) - return noSecret || afterDelete -} - type RegisterWorkspaceProxyResponse struct { - Keys []CryptoKey `json:"keys"` AppSecurityKey string `json:"app_security_key"` DERPMeshKey string `json:"derp_mesh_key"` DERPRegionID int32 `json:"derp_region_id"` From e6612bd79752e3894459f56f66cffd50a18999eb Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 1 Oct 2024 23:00:19 +0000 Subject: [PATCH 03/13] Refactor wsproxy to use cryptokeys interface - Implement `cryptokeys.Keycache` interface in `CryptoKeyCache`. - Introduce context management for graceful shutdowns. - Simplify function signatures and improve concurrency handling. - Ensure functions return errors when cache is closed. --- enterprise/wsproxy/keycache.go | 53 +++++++++++++++++++--- enterprise/wsproxy/keycache_test.go | 68 ++++++++++++++++++++++++----- 2 files changed, 102 insertions(+), 19 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index cbb79eadb1c7c..cff7138b83f75 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -9,12 +9,17 @@ import ( "cdr.dev/slog" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/quartz" ) +var _ cryptokeys.Keycache = &CryptoKeyCache{} + type CryptoKeyCache struct { + ctx context.Context + cancel context.CancelFunc client *wsproxysdk.Client logger slog.Logger Clock quartz.Clock @@ -22,6 +27,7 @@ type CryptoKeyCache struct { keysMu sync.RWMutex keys map[int32]codersdk.CryptoKey latest codersdk.CryptoKey + closed bool } func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.Client, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { @@ -40,14 +46,21 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk. return nil, xerrors.Errorf("initial fetch: %w", err) } cache.keys, cache.latest = m, latest + cache.ctx, cache.cancel = context.WithCancel(ctx) - go cache.refresh(ctx) + go cache.refresh() return cache, nil } -func (k *CryptoKeyCache) Latest(ctx context.Context) (codersdk.CryptoKey, error) { +func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { k.keysMu.RLock() + + if k.closed { + k.keysMu.RUnlock() + return codersdk.CryptoKey{}, cryptokeys.ErrClosed + } + latest := k.latest k.keysMu.RUnlock() @@ -59,6 +72,10 @@ func (k *CryptoKeyCache) Latest(ctx context.Context) (codersdk.CryptoKey, error) k.keysMu.Lock() defer k.keysMu.Unlock() + if k.closed { + return codersdk.CryptoKey{}, cryptokeys.ErrClosed + } + if k.latest.CanSign(now) { return k.latest, nil } @@ -76,9 +93,14 @@ func (k *CryptoKeyCache) Latest(ctx context.Context) (codersdk.CryptoKey, error) return k.latest, nil } -func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { +func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { now := k.Clock.Now().UTC() k.keysMu.RLock() + if k.closed { + k.keysMu.RUnlock() + return codersdk.CryptoKey{}, cryptokeys.ErrClosed + } + key, ok := k.keys[sequence] k.keysMu.RUnlock() if ok { @@ -87,6 +109,11 @@ func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (codersdk. k.keysMu.Lock() defer k.keysMu.Unlock() + + if k.closed { + return codersdk.CryptoKey{}, cryptokeys.ErrClosed + } + key, ok = k.keys[sequence] if ok { return validKey(key, now) @@ -106,11 +133,11 @@ func (k *CryptoKeyCache) Version(ctx context.Context, sequence int32) (codersdk. return validKey(key, now) } -func (k *CryptoKeyCache) refresh(ctx context.Context) { - k.Clock.TickerFunc(ctx, time.Minute*10, func() error { - kmap, latest, err := k.fetch(ctx) +func (k *CryptoKeyCache) refresh() { + k.Clock.TickerFunc(k.ctx, time.Minute*10, func() error { + kmap, latest, err := k.fetch(k.ctx) if err != nil { - k.logger.Error(ctx, "failed to fetch crypto keys", slog.Error(err)) + k.logger.Error(k.ctx, "failed to fetch crypto keys", slog.Error(err)) return nil } @@ -151,3 +178,15 @@ func validKey(key codersdk.CryptoKey, now time.Time) (codersdk.CryptoKey, error) return key, nil } + +func (k *CryptoKeyCache) Close() { + k.keysMu.Lock() + defer k.keysMu.Unlock() + + if k.closed { + return + } + + k.cancel() + k.closed = true +} diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index de66bc74d3cca..7267c1da1e033 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -12,6 +12,7 @@ import ( "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy" "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" @@ -61,7 +62,7 @@ func TestCryptoKeyCache(t *testing.T) { cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - got, err := cache.Latest(ctx) + got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) require.Equal(t, 1, fc.called) @@ -88,14 +89,14 @@ func TestCryptoKeyCache(t *testing.T) { } fc.keys = []codersdk.CryptoKey{expected} - got, err := cache.Latest(ctx) + got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) // 1 on startup + missing cache. require.Equal(t, 2, fc.called) // Ensure the cache gets hit this time. - got, err = cache.Latest(ctx) + got, err = cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) // 1 on startup + missing cache. @@ -132,7 +133,7 @@ func TestCryptoKeyCache(t *testing.T) { cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - got, err := cache.Latest(ctx) + got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) require.Equal(t, 1, fc.called) @@ -171,7 +172,7 @@ func TestCryptoKeyCache(t *testing.T) { cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - got, err := cache.Version(ctx, expected.Sequence) + got, err := cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) require.Equal(t, 1, fc.called) @@ -198,13 +199,13 @@ func TestCryptoKeyCache(t *testing.T) { } fc.keys = []codersdk.CryptoKey{expected} - got, err := cache.Version(ctx, expected.Sequence) + got, err := cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) require.Equal(t, 2, fc.called) // Ensure the cache gets hit this time. - got, err = cache.Version(ctx, expected.Sequence) + got, err = cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) require.Equal(t, 2, fc.called) @@ -234,7 +235,7 @@ func TestCryptoKeyCache(t *testing.T) { cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - got, err := cache.Version(ctx, expected.Sequence) + got, err := cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) require.Equal(t, 1, fc.called) @@ -265,7 +266,7 @@ func TestCryptoKeyCache(t *testing.T) { cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - _, err = cache.Version(ctx, expected.Sequence) + _, err = cache.Verifying(ctx, expected.Sequence) require.Error(t, err) require.Equal(t, 1, fc.called) }) @@ -297,7 +298,7 @@ func TestCryptoKeyCache(t *testing.T) { cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) require.NoError(t, err) - got, err := cache.Latest(ctx) + got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) require.Equal(t, 1, fc.called) @@ -320,15 +321,58 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, 2, fc.called) // Assert hits cache. - got, err = cache.Latest(ctx) + got, err = cache.Signing(ctx) require.NoError(t, err) require.Equal(t, newKey, got) require.Equal(t, 2, fc.called) // Assert we do not have the old key. - _, err = cache.Version(ctx, expected.Sequence) + _, err = cache.Verifying(ctx, expected.Sequence) require.Error(t, err) }) + + t.Run("Closed", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + DeletesAt: now.Add(time.Minute * 10), + } + fc := newFakeCoderd(t, []codersdk.CryptoKey{ + expected, + }) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + got, err := cache.Signing(ctx) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, fc.called) + + got, err = cache.Verifying(ctx, expected.Sequence) + require.NoError(t, err) + require.Equal(t, expected, got) + require.Equal(t, 1, fc.called) + + cache.Close() + + _, err = cache.Signing(ctx) + require.ErrorIs(t, err, cryptokeys.ErrClosed) + + _, err = cache.Verifying(ctx, expected.Sequence) + require.ErrorIs(t, err, cryptokeys.ErrClosed) + }) } type fakeCoderd struct { From 558ad30e8ec5ff3f1565dad03ca2985235503e66 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Tue, 1 Oct 2024 23:25:05 +0000 Subject: [PATCH 04/13] Refactor keycache to improve error handling - Use `cryptokeys.ErrKeyNotFound` for missing keys - Use `cryptokeys.ErrKeyInvalid` for invalid keys - Replace `xerrors` for streamlined error codes --- enterprise/wsproxy/keycache.go | 14 ++++---- enterprise/wsproxy/keycache_test.go | 53 ++++++++++++++++++++++++----- 2 files changed, 51 insertions(+), 16 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index cff7138b83f75..8883cb270e711 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -64,7 +64,7 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error latest := k.latest k.keysMu.RUnlock() - now := k.Clock.Now().UTC() + now := k.Clock.Now() if latest.CanSign(now) { return latest, nil } @@ -87,14 +87,14 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error } if !k.latest.CanSign(now) { - return codersdk.CryptoKey{}, xerrors.Errorf("no active keys found") + return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound } return k.latest, nil } func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - now := k.Clock.Now().UTC() + now := k.Clock.Now() k.keysMu.RLock() if k.closed { k.keysMu.RUnlock() @@ -127,7 +127,7 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd key, ok = k.keys[sequence] if !ok { - return codersdk.CryptoKey{}, xerrors.Errorf("key %d not found", sequence) + return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound } return validKey(key, now) @@ -155,7 +155,7 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe return nil, codersdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err) } - kmap, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now().UTC()) + kmap, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now()) return kmap, latest, nil } @@ -172,8 +172,8 @@ func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.Cryp } func validKey(key codersdk.CryptoKey, now time.Time) (codersdk.CryptoKey, error) { - if !key.CanSign(now) { - return codersdk.CryptoKey{}, xerrors.Errorf("key %d is invalid", key.Sequence) + if !key.CanVerify(now) { + return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid } return key, nil diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index 7267c1da1e033..f141f1810bfc4 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -23,7 +23,7 @@ import ( func TestCryptoKeyCache(t *testing.T) { t.Parallel() - t.Run("Latest", func(t *testing.T) { + t.Run("Signing", func(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -138,9 +138,27 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, expected, got) require.Equal(t, 1, fc.called) }) + + t.Run("KeyNotFound", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + fc := newFakeCoderd(t, []codersdk.CryptoKey{}) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + _, err = cache.Verifying(ctx, 1) + require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) + }) }) - t.Run("Version", func(t *testing.T) { + t.Run("Verifying", func(t *testing.T) { t.Parallel() t.Run("HitsCache", func(t *testing.T) { @@ -241,7 +259,7 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, 1, fc.called) }) - t.Run("NoInvalid", func(t *testing.T) { + t.Run("KeyInvalid", func(t *testing.T) { t.Parallel() var ( @@ -267,9 +285,27 @@ func TestCryptoKeyCache(t *testing.T) { require.NoError(t, err) _, err = cache.Verifying(ctx, expected.Sequence) - require.Error(t, err) + require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid) require.Equal(t, 1, fc.called) }) + + t.Run("KeyNotFound", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + fc := newFakeCoderd(t, []codersdk.CryptoKey{}) + + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + _, err = cache.Verifying(ctx, 1) + require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) + }) }) t.Run("CacheRefreshes", func(t *testing.T) { @@ -342,11 +378,10 @@ func TestCryptoKeyCache(t *testing.T) { now := clock.Now() expected := codersdk.CryptoKey{ - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 12, - StartsAt: now, - DeletesAt: now.Add(time.Minute * 10), + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, } fc := newFakeCoderd(t, []codersdk.CryptoKey{ expected, From 9061a1c1100d62c745ae8d3c00d62734be5d55cd Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 07:10:22 +0000 Subject: [PATCH 05/13] prevent cache slowdown during fetches --- enterprise/wsproxy/keycache.go | 118 +++++++++++++++++++-------------- 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index 8883cb270e711..51e8d94265385 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -3,6 +3,7 @@ package wsproxy import ( "context" "sync" + "sync/atomic" "time" "golang.org/x/xerrors" @@ -18,16 +19,19 @@ import ( var _ cryptokeys.Keycache = &CryptoKeyCache{} type CryptoKeyCache struct { - ctx context.Context - cancel context.CancelFunc - client *wsproxysdk.Client - logger slog.Logger - Clock quartz.Clock - - keysMu sync.RWMutex - keys map[int32]codersdk.CryptoKey - latest codersdk.CryptoKey - closed bool + refreshCtx context.Context + refreshCancel context.CancelFunc + client *wsproxysdk.Client + logger slog.Logger + Clock quartz.Clock + + keysMu sync.RWMutex + keys map[int32]codersdk.CryptoKey + latest codersdk.CryptoKey + fetchLock sync.RWMutex + lastFetch time.Time + refresher *quartz.Timer + closed atomic.Bool } func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.Client, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { @@ -46,21 +50,17 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk. return nil, xerrors.Errorf("initial fetch: %w", err) } cache.keys, cache.latest = m, latest - cache.ctx, cache.cancel = context.WithCancel(ctx) - - go cache.refresh() + cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh) return cache, nil } func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { - k.keysMu.RLock() - - if k.closed { - k.keysMu.RUnlock() + if k.isClosed() { return codersdk.CryptoKey{}, cryptokeys.ErrClosed } + k.keysMu.RLock() latest := k.latest k.keysMu.RUnlock() @@ -69,34 +69,31 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error return latest, nil } - k.keysMu.Lock() - defer k.keysMu.Unlock() + k.fetchLock.Lock() + defer k.fetchLock.Unlock() - if k.closed { + if k.isClosed() { return codersdk.CryptoKey{}, cryptokeys.ErrClosed } + k.keysMu.RLock() if k.latest.CanSign(now) { + k.keysMu.RUnlock() return k.latest, nil } - var err error - k.keys, k.latest, err = k.fetch(ctx) + _, latest, err := k.fetch(ctx) if err != nil { return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) } - if !k.latest.CanSign(now) { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound - } - - return k.latest, nil + return latest, nil } func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { now := k.Clock.Now() k.keysMu.RLock() - if k.closed { + if k.isClosed() { k.keysMu.RUnlock() return codersdk.CryptoKey{}, cryptokeys.ErrClosed } @@ -110,7 +107,7 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd k.keysMu.Lock() defer k.keysMu.Unlock() - if k.closed { + if k.isClosed() { return codersdk.CryptoKey{}, cryptokeys.ErrClosed } @@ -119,13 +116,12 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd return validKey(key, now) } - var err error - k.keys, k.latest, err = k.fetch(ctx) + keys, _, err := k.fetch(ctx) if err != nil { return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) } - key, ok = k.keys[sequence] + key, ok = keys[sequence] if !ok { return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound } @@ -134,28 +130,50 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd } func (k *CryptoKeyCache) refresh() { - k.Clock.TickerFunc(k.ctx, time.Minute*10, func() error { - kmap, latest, err := k.fetch(k.ctx) - if err != nil { - k.logger.Error(k.ctx, "failed to fetch crypto keys", slog.Error(err)) - return nil - } + if k.isClosed() { + return + } + + k.keysMu.RLock() + if k.Clock.Now().Sub(k.lastFetch) < time.Minute*10 { + k.keysMu.Unlock() + return + } + + k.fetchLock.Lock() + defer k.fetchLock.Unlock() - k.keysMu.Lock() - defer k.keysMu.Unlock() - k.keys = kmap - k.latest = latest - return nil - }) + _, _, err := k.fetch(k.refreshCtx) + if err != nil { + k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err)) + return + } } func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { + keys, err := k.client.CryptoKeys(ctx) if err != nil { return nil, codersdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err) } - kmap, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now()) + if len(keys.CryptoKeys) == 0 { + return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound + } + + now := k.Clock.Now() + kmap, latest := toKeyMap(keys.CryptoKeys, now) + if !latest.CanSign(now) { + return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid + } + + k.keysMu.Lock() + defer k.keysMu.Unlock() + + k.lastFetch = k.Clock.Now() + k.refresher.Reset(time.Minute * 10) + k.keys, k.latest = kmap, latest + return kmap, latest, nil } @@ -179,14 +197,18 @@ func validKey(key codersdk.CryptoKey, now time.Time) (codersdk.CryptoKey, error) return key, nil } +func (k *CryptoKeyCache) isClosed() bool { + return k.closed.Load() +} + func (k *CryptoKeyCache) Close() { k.keysMu.Lock() defer k.keysMu.Unlock() - if k.closed { + if k.isClosed() { return } - k.cancel() - k.closed = true + k.refreshCancel() + k.closed.Store(true) } From 0fef6b0569dc903bc1a0c4c3e6b8d57ef0705c8d Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 08:40:09 +0000 Subject: [PATCH 06/13] fix tests --- enterprise/wsproxy/keycache.go | 73 ++++++++++++++++++++--------- enterprise/wsproxy/keycache_test.go | 13 ++--- 2 files changed, 55 insertions(+), 31 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index 51e8d94265385..bdde6c769e9d2 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -45,12 +45,14 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk. opt(cache) } - m, latest, err := cache.fetch(ctx) + cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) + cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh) + m, latest, err := cache.fetchKeys(ctx) if err != nil { + cache.refreshCancel() return nil, xerrors.Errorf("initial fetch: %w", err) } cache.keys, cache.latest = m, latest - cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh) return cache, nil } @@ -77,9 +79,12 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error } k.keysMu.RLock() - if k.latest.CanSign(now) { - k.keysMu.RUnlock() - return k.latest, nil + latest = k.latest + k.keysMu.RUnlock() + + now = k.Clock.Now() + if latest.CanSign(now) { + return latest, nil } _, latest, err := k.fetch(ctx) @@ -91,27 +96,28 @@ func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error } func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - now := k.Clock.Now() - k.keysMu.RLock() if k.isClosed() { - k.keysMu.RUnlock() return codersdk.CryptoKey{}, cryptokeys.ErrClosed } + now := k.Clock.Now() + k.keysMu.RLock() key, ok := k.keys[sequence] k.keysMu.RUnlock() if ok { return validKey(key, now) } - k.keysMu.Lock() - defer k.keysMu.Unlock() + k.fetchLock.Lock() + defer k.fetchLock.Unlock() if k.isClosed() { return codersdk.CryptoKey{}, cryptokeys.ErrClosed } + k.keysMu.RLock() key, ok = k.keys[sequence] + k.keysMu.RUnlock() if ok { return validKey(key, now) } @@ -134,14 +140,23 @@ func (k *CryptoKeyCache) refresh() { return } - k.keysMu.RLock() - if k.Clock.Now().Sub(k.lastFetch) < time.Minute*10 { - k.keysMu.Unlock() + k.fetchLock.Lock() + defer k.fetchLock.Unlock() + + if k.isClosed() { return } - k.fetchLock.Lock() - defer k.fetchLock.Unlock() + k.keysMu.RLock() + lastFetch := k.lastFetch + k.keysMu.RUnlock() + + // There's a window we must account for where the timer fires while a fetch + // is ongoing but prior to the timer getting reset. In this case we want to + // avoid double fetching. + if k.Clock.Now().Sub(lastFetch) < time.Minute*10 { + return + } _, _, err := k.fetch(k.refreshCtx) if err != nil { @@ -150,19 +165,28 @@ func (k *CryptoKeyCache) refresh() { } } -func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { - +func (k *CryptoKeyCache) fetchKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { keys, err := k.client.CryptoKeys(ctx) if err != nil { - return nil, codersdk.CryptoKey{}, xerrors.Errorf("get security keys: %w", err) + return nil, codersdk.CryptoKey{}, xerrors.Errorf("crypto keys: %w", err) } + cache, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now()) + return cache, latest, nil +} - if len(keys.CryptoKeys) == 0 { +// fetch fetches the keys from the control plane and updates the cache. The fetchMu +// must be held when calling this function to avoid multiple concurrent fetches. +func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { + keys, latest, err := k.fetchKeys(ctx) + if err != nil { + return nil, codersdk.CryptoKey{}, xerrors.Errorf("fetch keys: %w", err) + } + + if len(keys) == 0 { return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound } now := k.Clock.Now() - kmap, latest := toKeyMap(keys.CryptoKeys, now) if !latest.CanSign(now) { return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid } @@ -172,9 +196,9 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe k.lastFetch = k.Clock.Now() k.refresher.Reset(time.Minute * 10) - k.keys, k.latest = kmap, latest + k.keys, k.latest = keys, latest - return kmap, latest, nil + return keys, latest, nil } func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey) { @@ -202,6 +226,11 @@ func (k *CryptoKeyCache) isClosed() bool { } func (k *CryptoKeyCache) Close() { + // The fetch lock must always be held before holding the keys lock + // otherwise we risk a deadlock. + k.fetchLock.Lock() + defer k.fetchLock.Unlock() + k.keysMu.Lock() defer k.keysMu.Unlock() diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index f141f1810bfc4..af66f0bb551c2 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -317,8 +317,6 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) - trap := clock.Trap().TickerFunc() - now := clock.Now().UTC() expected := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, @@ -339,8 +337,6 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, expected, got) require.Equal(t, 1, fc.called) - wait := trap.MustWait(ctx) - newKey := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, Secret: "key2", @@ -349,8 +345,6 @@ func TestCryptoKeyCache(t *testing.T) { } fc.keys = []codersdk.CryptoKey{newKey} - wait.Release() - // The ticker should fire and cause a request to coderd. _, advance := clock.AdvanceNext() advance.MustWait(ctx) @@ -362,9 +356,10 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, newKey, got) require.Equal(t, 2, fc.called) - // Assert we do not have the old key. - _, err = cache.Verifying(ctx, expected.Sequence) - require.Error(t, err) + // The ticker should fire and cause a request to coderd. + _, advance = clock.AdvanceNext() + advance.MustWait(ctx) + require.Equal(t, 3, fc.called) }) t.Run("Closed", func(t *testing.T) { From a801d0cc244f93ff045b824fa86fe2fa72ad25a4 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 22:16:52 +0000 Subject: [PATCH 07/13] craft test for race condition --- enterprise/wsproxy/keycache.go | 24 +++++++---- enterprise/wsproxy/keycache_test.go | 67 ++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index bdde6c769e9d2..3d6456229b4c0 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -2,6 +2,7 @@ package wsproxy import ( "context" + "maps" "sync" "sync/atomic" "time" @@ -47,7 +48,7 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk. cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh) - m, latest, err := cache.fetchKeys(ctx) + m, latest, err := cache.cryptoKeys(ctx) if err != nil { cache.refreshCancel() return nil, xerrors.Errorf("initial fetch: %w", err) @@ -100,10 +101,11 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd return codersdk.CryptoKey{}, cryptokeys.ErrClosed } - now := k.Clock.Now() k.keysMu.RLock() key, ok := k.keys[sequence] k.keysMu.RUnlock() + + now := k.Clock.Now() if ok { return validKey(key, now) } @@ -135,11 +137,13 @@ func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersd return validKey(key, now) } +// refresh fetches the keys from the control plane and updates the cache. func (k *CryptoKeyCache) refresh() { if k.isClosed() { return } + now := k.Clock.Now("CryptoKeyCache", "refresh") k.fetchLock.Lock() defer k.fetchLock.Unlock() @@ -154,7 +158,7 @@ func (k *CryptoKeyCache) refresh() { // There's a window we must account for where the timer fires while a fetch // is ongoing but prior to the timer getting reset. In this case we want to // avoid double fetching. - if k.Clock.Now().Sub(lastFetch) < time.Minute*10 { + if now.Sub(lastFetch) < time.Minute*10 { return } @@ -165,7 +169,9 @@ func (k *CryptoKeyCache) refresh() { } } -func (k *CryptoKeyCache) fetchKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { +// cryptoKeys queries the control plane for the crypto keys. +// Outside of initialization, this should only be called by fetch. +func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { keys, err := k.client.CryptoKeys(ctx) if err != nil { return nil, codersdk.CryptoKey{}, xerrors.Errorf("crypto keys: %w", err) @@ -176,8 +182,9 @@ func (k *CryptoKeyCache) fetchKeys(ctx context.Context) (map[int32]codersdk.Cryp // fetch fetches the keys from the control plane and updates the cache. The fetchMu // must be held when calling this function to avoid multiple concurrent fetches. +// The returned keys are safe to use without additional locking. func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { - keys, latest, err := k.fetchKeys(ctx) + keys, latest, err := k.cryptoKeys(ctx) if err != nil { return nil, codersdk.CryptoKey{}, xerrors.Errorf("fetch keys: %w", err) } @@ -196,7 +203,7 @@ func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKe k.lastFetch = k.Clock.Now() k.refresher.Reset(time.Minute * 10) - k.keys, k.latest = keys, latest + k.keys, k.latest = maps.Clone(keys), latest return keys, latest, nil } @@ -226,8 +233,8 @@ func (k *CryptoKeyCache) isClosed() bool { } func (k *CryptoKeyCache) Close() { - // The fetch lock must always be held before holding the keys lock - // otherwise we risk a deadlock. + // It's important to hold the locks here so that we don't unintentionally + // reset the timer via an in flight request when Close is called. k.fetchLock.Lock() defer k.fetchLock.Unlock() @@ -239,5 +246,6 @@ func (k *CryptoKeyCache) Close() { } k.refreshCancel() + k.refresher.Stop() k.closed.Store(true) } diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index af66f0bb551c2..7af9259619949 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/stretchr/testify/require" + "go.uber.org/goleak" "cdr.dev/slog/sloggers/slogtest" @@ -20,6 +21,10 @@ import ( "github.com/coder/quartz" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestCryptoKeyCache(t *testing.T) { t.Parallel() @@ -346,9 +351,10 @@ func TestCryptoKeyCache(t *testing.T) { fc.keys = []codersdk.CryptoKey{newKey} // The ticker should fire and cause a request to coderd. - _, advance := clock.AdvanceNext() + dur, advance := clock.AdvanceNext() advance.MustWait(ctx) require.Equal(t, 2, fc.called) + require.Equal(t, time.Minute*10, dur) // Assert hits cache. got, err = cache.Signing(ctx) @@ -356,10 +362,67 @@ func TestCryptoKeyCache(t *testing.T) { require.Equal(t, newKey, got) require.Equal(t, 2, fc.called) - // The ticker should fire and cause a request to coderd. + // We check again to ensure the timer has been reset. _, advance = clock.AdvanceNext() advance.MustWait(ctx) require.Equal(t, 3, fc.called) + require.Equal(t, time.Minute*10, dur) + }) + + // This test ensures that if the refresh timer races with an inflight request + // and loses that it doesn't cause a redundant fetch. + + t.Run("RefreshNoDoubleFetch", func(t *testing.T) { + t.Parallel() + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, nil) + clock = quartz.NewMock(t) + ) + + now := clock.Now().UTC() + expected := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key1", + Sequence: 12, + StartsAt: now, + DeletesAt: now.Add(time.Minute * 10), + } + fc := newFakeCoderd(t, []codersdk.CryptoKey{ + expected, + }) + + // Create a trap that blocks when the refresh timer fires. + trap := clock.Trap().Now("refresh") + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + require.NoError(t, err) + + _, wait := clock.AdvanceNext() + trapped := trap.MustWait(ctx) + + newKey := codersdk.CryptoKey{ + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 13, + StartsAt: now, + } + fc.keys = []codersdk.CryptoKey{newKey} + + _, err = cache.Verifying(ctx, newKey.Sequence) + require.NoError(t, err) + require.Equal(t, 2, fc.called) + + trapped.Release() + wait.MustWait(ctx) + require.Equal(t, 2, fc.called) + trap.Close() + + // The next timer should fire in 10 minutes. + dur, wait := clock.AdvanceNext() + wait.MustWait(ctx) + require.Equal(t, time.Minute*10, dur) + require.Equal(t, 3, fc.called) }) t.Run("Closed", func(t *testing.T) { From 2ce48765512aa7c5356c121e0245e2faefc666b8 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 22:53:40 +0000 Subject: [PATCH 08/13] remove gotestleak...too many unrelated errors --- enterprise/wsproxy/keycache_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index 7af9259619949..ec9dbb7fa00d6 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -9,7 +9,6 @@ import ( "time" "github.com/stretchr/testify/require" - "go.uber.org/goleak" "cdr.dev/slog/sloggers/slogtest" @@ -21,10 +20,6 @@ import ( "github.com/coder/quartz" ) -func TestMain(m *testing.M) { - goleak.VerifyTestMain(m) -} - func TestCryptoKeyCache(t *testing.T) { t.Parallel() From 5eedbf885a85a03ccc87c671840db50e1658dcc4 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Wed, 2 Oct 2024 23:01:19 +0000 Subject: [PATCH 09/13] optimize Close --- enterprise/wsproxy/keycache.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index 3d6456229b4c0..a30f43e2620d8 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -233,19 +233,20 @@ func (k *CryptoKeyCache) isClosed() bool { } func (k *CryptoKeyCache) Close() { + if k.isClosed() { + return + } + + k.refreshCancel() + k.closed.Store(true) + // It's important to hold the locks here so that we don't unintentionally - // reset the timer via an in flight request when Close is called. + // reset the timer via an in flight request after Close returns. k.fetchLock.Lock() defer k.fetchLock.Unlock() k.keysMu.Lock() defer k.keysMu.Unlock() - if k.isClosed() { - return - } - - k.refreshCancel() k.refresher.Stop() - k.closed.Store(true) } From 6c2be2c740e28bc2423fb99b9243891d66f7f9ea Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 4 Oct 2024 06:46:51 +0000 Subject: [PATCH 10/13] Refactor CryptoKeyCache to improve concurrency - Replaced sync/atomic with sync.Mutex and sync.Cond for better concurrency control, avoiding frequent lock contention and reducing potential for data races. - Removed separate fetch logic and integrated it into the main crypto key retrieval flow to handle concurrent fetches properly. - Optimized the Close method to minimize wait time by broadcasting on cond to notify waiting fetches when closed. --- enterprise/wsproxy/keycache.go | 229 +++++++++++++--------------- enterprise/wsproxy/keycache_test.go | 221 ++++++++++++--------------- 2 files changed, 203 insertions(+), 247 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index a30f43e2620d8..6f6db5cd9769c 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -2,9 +2,7 @@ package wsproxy import ( "context" - "maps" "sync" - "sync/atomic" "time" "golang.org/x/xerrors" @@ -13,41 +11,54 @@ import ( "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/quartz" ) +const ( + // latestSequence is a special sequence number that represents the latest key. + latestSequence = -1 + // refreshInterval is the interval at which the key cache will refresh. + refreshInterval = time.Minute * 10 +) + +type Fetcher interface { + Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) +} + var _ cryptokeys.Keycache = &CryptoKeyCache{} type CryptoKeyCache struct { + Clock quartz.Clock refreshCtx context.Context refreshCancel context.CancelFunc - client *wsproxysdk.Client + fetcher Fetcher logger slog.Logger - Clock quartz.Clock - keysMu sync.RWMutex + mu sync.Mutex keys map[int32]codersdk.CryptoKey latest codersdk.CryptoKey - fetchLock sync.RWMutex lastFetch time.Time refresher *quartz.Timer - closed atomic.Bool + fetching bool + closed bool + cond *sync.Cond } -func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk.Client, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { +func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) { cache := &CryptoKeyCache{ - client: client, - logger: log, - Clock: quartz.NewReal(), + Clock: quartz.NewReal(), + logger: log, + fetcher: client, } for _, opt := range opts { opt(cache) } + cache.cond = sync.NewCond(&cache.mu) cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) - cache.refresher = cache.Clock.AfterFunc(time.Minute*10, cache.refresh) + cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh) + m, latest, err := cache.cryptoKeys(ctx) if err != nil { cache.refreshCancel() @@ -59,155 +70,136 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client *wsproxysdk. } func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) { - if k.isClosed() { - return codersdk.CryptoKey{}, cryptokeys.ErrClosed - } - - k.keysMu.RLock() - latest := k.latest - k.keysMu.RUnlock() + return k.cryptoKey(ctx, latestSequence) +} - now := k.Clock.Now() - if latest.CanSign(now) { - return latest, nil - } +func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { + return k.cryptoKey(ctx, sequence) +} - k.fetchLock.Lock() - defer k.fetchLock.Unlock() +func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { + k.mu.Lock() + defer k.mu.Unlock() - if k.isClosed() { + if k.closed { return codersdk.CryptoKey{}, cryptokeys.ErrClosed } - k.keysMu.RLock() - latest = k.latest - k.keysMu.RUnlock() + var key codersdk.CryptoKey + var ok bool + for key, ok = k.key(sequence); !ok && k.fetching && !k.closed; { + k.cond.Wait() + } - now = k.Clock.Now() - if latest.CanSign(now) { - return latest, nil + if k.closed { + return codersdk.CryptoKey{}, cryptokeys.ErrClosed } - _, latest, err := k.fetch(ctx) - if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) + if ok { + return checkKey(key, sequence, k.Clock.Now()) } - return latest, nil -} + k.fetching = true + k.mu.Unlock() -func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) { - if k.isClosed() { - return codersdk.CryptoKey{}, cryptokeys.ErrClosed + keys, latest, err := k.cryptoKeys(ctx) + if err != nil { + return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err) } - k.keysMu.RLock() - key, ok := k.keys[sequence] - k.keysMu.RUnlock() + k.mu.Lock() + k.lastFetch = k.Clock.Now() + k.refresher.Reset(refreshInterval) + k.keys, k.latest = keys, latest + k.fetching = false + k.cond.Broadcast() - now := k.Clock.Now() - if ok { - return validKey(key, now) + key, ok = k.key(sequence) + if !ok { + return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound } - k.fetchLock.Lock() - defer k.fetchLock.Unlock() + return checkKey(key, sequence, k.Clock.Now()) +} - if k.isClosed() { - return codersdk.CryptoKey{}, cryptokeys.ErrClosed +func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) { + if sequence == latestSequence { + return k.latest, k.latest.CanSign(k.Clock.Now()) } - k.keysMu.RLock() - key, ok = k.keys[sequence] - k.keysMu.RUnlock() - if ok { - return validKey(key, now) - } + key, ok := k.keys[sequence] + return key, ok +} - keys, _, err := k.fetch(ctx) - if err != nil { - return codersdk.CryptoKey{}, xerrors.Errorf("fetch: %w", err) +func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) { + if sequence == latestSequence { + if !key.CanSign(now) { + return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid + } + return key, nil } - key, ok = keys[sequence] - if !ok { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound + if !key.CanVerify(now) { + return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid } - return validKey(key, now) + return key, nil } // refresh fetches the keys from the control plane and updates the cache. func (k *CryptoKeyCache) refresh() { - if k.isClosed() { + if k.closed { return } now := k.Clock.Now("CryptoKeyCache", "refresh") - k.fetchLock.Lock() - defer k.fetchLock.Unlock() - if k.isClosed() { + k.mu.Lock() + + // If something's already fetching, we don't need to do anything. + if k.fetching { + k.mu.Unlock() return } - k.keysMu.RLock() - lastFetch := k.lastFetch - k.keysMu.RUnlock() - // There's a window we must account for where the timer fires while a fetch // is ongoing but prior to the timer getting reset. In this case we want to // avoid double fetching. - if now.Sub(lastFetch) < time.Minute*10 { + if now.Sub(k.lastFetch) < refreshInterval { + k.mu.Unlock() return } - _, _, err := k.fetch(k.refreshCtx) + k.fetching = true + + k.mu.Unlock() + keys, latest, err := k.cryptoKeys(k.refreshCtx) if err != nil { k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err)) return } + + k.mu.Lock() + defer k.mu.Unlock() + + k.lastFetch = k.Clock.Now() + k.refresher.Reset(refreshInterval) + k.keys, k.latest = keys, latest + k.fetching = false + k.cond.Broadcast() } // cryptoKeys queries the control plane for the crypto keys. // Outside of initialization, this should only be called by fetch. func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { - keys, err := k.client.CryptoKeys(ctx) + keys, err := k.fetcher.Fetch(ctx) if err != nil { return nil, codersdk.CryptoKey{}, xerrors.Errorf("crypto keys: %w", err) } - cache, latest := toKeyMap(keys.CryptoKeys, k.Clock.Now()) + cache, latest := toKeyMap(keys, k.Clock.Now()) return cache, latest, nil } -// fetch fetches the keys from the control plane and updates the cache. The fetchMu -// must be held when calling this function to avoid multiple concurrent fetches. -// The returned keys are safe to use without additional locking. -func (k *CryptoKeyCache) fetch(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { - keys, latest, err := k.cryptoKeys(ctx) - if err != nil { - return nil, codersdk.CryptoKey{}, xerrors.Errorf("fetch keys: %w", err) - } - - if len(keys) == 0 { - return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound - } - - now := k.Clock.Now() - if !latest.CanSign(now) { - return nil, codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid - } - - k.keysMu.Lock() - defer k.keysMu.Unlock() - - k.lastFetch = k.Clock.Now() - k.refresher.Reset(time.Minute * 10) - k.keys, k.latest = maps.Clone(keys), latest - - return keys, latest, nil -} - func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey) { m := make(map[int32]codersdk.CryptoKey) var latest codersdk.CryptoKey @@ -220,33 +212,16 @@ func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.Cryp return m, latest } -func validKey(key codersdk.CryptoKey, now time.Time) (codersdk.CryptoKey, error) { - if !key.CanVerify(now) { - return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid - } - - return key, nil -} - -func (k *CryptoKeyCache) isClosed() bool { - return k.closed.Load() -} - func (k *CryptoKeyCache) Close() { - if k.isClosed() { + k.mu.Lock() + defer k.mu.Unlock() + + if k.closed { return } + k.closed = true k.refreshCancel() - k.closed.Store(true) - - // It's important to hold the locks here so that we don't unintentionally - // reset the timer via an in flight request after Close returns. - k.fetchLock.Lock() - defer k.fetchLock.Unlock() - - k.keysMu.Lock() - defer k.keysMu.Unlock() - k.refresher.Stop() + k.cond.Broadcast() } diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index ec9dbb7fa00d6..10e171c12e017 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -1,10 +1,7 @@ package wsproxy_test import ( - "encoding/json" - "net/http" - "net/http/httptest" - "net/url" + "context" "testing" "time" @@ -15,7 +12,6 @@ import ( "github.com/coder/coder/v2/coderd/cryptokeys" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/wsproxy" - "github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk" "github.com/coder/coder/v2/testutil" "github.com/coder/quartz" ) @@ -42,30 +38,17 @@ func TestCryptoKeyCache(t *testing.T) { StartsAt: now, } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key1", - Sequence: 1, - StartsAt: now, - }, - // Should be ignored since it hasn't breached its starts_at time yet. - { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key3", - Sequence: 3, - StartsAt: now.Add(time.Second * 2), - }, - expected, - }) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{expected}, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) }) t.Run("MissesCache", func(t *testing.T) { @@ -76,9 +59,11 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) - fc := newFakeCoderd(t, []codersdk.CryptoKey{}) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ @@ -87,20 +72,20 @@ func TestCryptoKeyCache(t *testing.T) { Sequence: 12, StartsAt: clock.Now().UTC(), } - fc.keys = []codersdk.CryptoKey{expected} + ff.keys = []codersdk.CryptoKey{expected} got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) // 1 on startup + missing cache. - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) // Ensure the cache gets hit this time. got, err = cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) // 1 on startup + missing cache. - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) }) t.Run("IgnoresInvalid", func(t *testing.T) { @@ -119,24 +104,26 @@ func TestCryptoKeyCache(t *testing.T) { StartsAt: clock.Now().UTC(), } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - expected, - { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", - Sequence: 2, - StartsAt: now.Add(-time.Second), - DeletesAt: now, + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + { + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 2, + StartsAt: now.Add(-time.Second), + DeletesAt: now, + }, }, - }) + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) }) t.Run("KeyNotFound", func(t *testing.T) { @@ -148,12 +135,14 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) - fc := newFakeCoderd(t, []codersdk.CryptoKey{}) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) - _, err = cache.Verifying(ctx, 1) + _, err = cache.Signing(ctx) require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound) }) }) @@ -177,23 +166,25 @@ func TestCryptoKeyCache(t *testing.T) { Sequence: 12, StartsAt: now, } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - expected, - { - Feature: codersdk.CryptoKeyFeatureWorkspaceApp, - Secret: "key2", - Sequence: 13, - StartsAt: now, + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + { + Feature: codersdk.CryptoKeyFeatureWorkspaceApp, + Secret: "key2", + Sequence: 13, + StartsAt: now, + }, }, - }) + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) got, err := cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) }) t.Run("MissesCache", func(t *testing.T) { @@ -204,9 +195,11 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) - fc := newFakeCoderd(t, []codersdk.CryptoKey{}) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) expected := codersdk.CryptoKey{ @@ -215,18 +208,18 @@ func TestCryptoKeyCache(t *testing.T) { Sequence: 12, StartsAt: clock.Now().UTC(), } - fc.keys = []codersdk.CryptoKey{expected} + ff.keys = []codersdk.CryptoKey{expected} got, err := cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) // Ensure the cache gets hit this time. got, err = cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) }) t.Run("AllowsBeforeStartsAt", func(t *testing.T) { @@ -246,17 +239,19 @@ func TestCryptoKeyCache(t *testing.T) { StartsAt: now.Add(-time.Second), } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - expected, - }) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) got, err := cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) }) t.Run("KeyInvalid", func(t *testing.T) { @@ -277,16 +272,18 @@ func TestCryptoKeyCache(t *testing.T) { DeletesAt: now, } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - expected, - }) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) _, err = cache.Verifying(ctx, expected.Sequence) require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) }) t.Run("KeyNotFound", func(t *testing.T) { @@ -298,9 +295,11 @@ func TestCryptoKeyCache(t *testing.T) { clock = quartz.NewMock(t) ) - fc := newFakeCoderd(t, []codersdk.CryptoKey{}) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{}, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) _, err = cache.Verifying(ctx, 1) @@ -325,17 +324,19 @@ func TestCryptoKeyCache(t *testing.T) { StartsAt: now, DeletesAt: now.Add(time.Minute * 10), } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - expected, - }) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) newKey := codersdk.CryptoKey{ Feature: codersdk.CryptoKeyFeatureWorkspaceApp, @@ -343,24 +344,24 @@ func TestCryptoKeyCache(t *testing.T) { Sequence: 13, StartsAt: now, } - fc.keys = []codersdk.CryptoKey{newKey} + ff.keys = []codersdk.CryptoKey{newKey} // The ticker should fire and cause a request to coderd. dur, advance := clock.AdvanceNext() advance.MustWait(ctx) - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) require.Equal(t, time.Minute*10, dur) // Assert hits cache. got, err = cache.Signing(ctx) require.NoError(t, err) require.Equal(t, newKey, got) - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) // We check again to ensure the timer has been reset. _, advance = clock.AdvanceNext() advance.MustWait(ctx) - require.Equal(t, 3, fc.called) + require.Equal(t, 3, ff.called) require.Equal(t, time.Minute*10, dur) }) @@ -384,13 +385,15 @@ func TestCryptoKeyCache(t *testing.T) { StartsAt: now, DeletesAt: now.Add(time.Minute * 10), } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - expected, - }) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, + } // Create a trap that blocks when the refresh timer fires. trap := clock.Trap().Now("refresh") - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) _, wait := clock.AdvanceNext() @@ -402,22 +405,22 @@ func TestCryptoKeyCache(t *testing.T) { Sequence: 13, StartsAt: now, } - fc.keys = []codersdk.CryptoKey{newKey} + ff.keys = []codersdk.CryptoKey{newKey} _, err = cache.Verifying(ctx, newKey.Sequence) require.NoError(t, err) - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) trapped.Release() wait.MustWait(ctx) - require.Equal(t, 2, fc.called) + require.Equal(t, 2, ff.called) trap.Close() // The next timer should fire in 10 minutes. dur, wait := clock.AdvanceNext() wait.MustWait(ctx) require.Equal(t, time.Minute*10, dur) - require.Equal(t, 3, fc.called) + require.Equal(t, 3, ff.called) }) t.Run("Closed", func(t *testing.T) { @@ -436,22 +439,24 @@ func TestCryptoKeyCache(t *testing.T) { Sequence: 12, StartsAt: now, } - fc := newFakeCoderd(t, []codersdk.CryptoKey{ - expected, - }) + ff := &fakeFetcher{ + keys: []codersdk.CryptoKey{ + expected, + }, + } - cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, wsproxysdk.New(fc.url), withClock(clock)) + cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock)) require.NoError(t, err) got, err := cache.Signing(ctx) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) got, err = cache.Verifying(ctx, expected.Sequence) require.NoError(t, err) require.Equal(t, expected, got) - require.Equal(t, 1, fc.called) + require.Equal(t, 1, ff.called) cache.Close() @@ -463,38 +468,14 @@ func TestCryptoKeyCache(t *testing.T) { }) } -type fakeCoderd struct { - server *httptest.Server +type fakeFetcher struct { keys []codersdk.CryptoKey called int - url *url.URL } -func newFakeCoderd(t *testing.T, keys []codersdk.CryptoKey) *fakeCoderd { - t.Helper() - - c := &fakeCoderd{ - keys: keys, - } - - mux := http.NewServeMux() - mux.HandleFunc("/api/v2/workspaceproxies/me/crypto-keys", func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - err := json.NewEncoder(w).Encode(wsproxysdk.CryptoKeysResponse{ - CryptoKeys: c.keys, - }) - require.NoError(t, err) - c.called++ - }) - - c.server = httptest.NewServer(mux) - t.Cleanup(c.server.Close) - - var err error - c.url, err = url.Parse(c.server.URL) - require.NoError(t, err) - - return c +func (f *fakeFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { + f.called++ + return f.keys, nil } func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) { From a1d4b46917aed142dd879b5a6cccc97d82322825 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 4 Oct 2024 06:51:00 +0000 Subject: [PATCH 11/13] Refactor key refresh to minimize lock duration --- enterprise/wsproxy/keycache.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index 6f6db5cd9769c..4d5db6070ffb1 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -148,14 +148,14 @@ func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.C // refresh fetches the keys from the control plane and updates the cache. func (k *CryptoKeyCache) refresh() { + now := k.Clock.Now("CryptoKeyCache", "refresh") + k.mu.Lock() + if k.closed { + k.mu.Unlock() return } - now := k.Clock.Now("CryptoKeyCache", "refresh") - - k.mu.Lock() - // If something's already fetching, we don't need to do anything. if k.fetching { k.mu.Unlock() From b33616e6e4c5e70802ce86f6f26661bfaedd292c Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 4 Oct 2024 06:56:24 +0000 Subject: [PATCH 12/13] Remove unused interface implementation check The unused `cryptokeys.Keycache` interface implementation check was removed from the `CryptoKeyCache` struct, as it was unnecessary and did not contribute to functionality. This streamlines the code and adheres to best practices for interface checks. --- enterprise/wsproxy/keycache.go | 2 -- enterprise/wsproxy/keycache_test.go | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index 4d5db6070ffb1..5fdd9c356fa85 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -25,8 +25,6 @@ type Fetcher interface { Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) } -var _ cryptokeys.Keycache = &CryptoKeyCache{} - type CryptoKeyCache struct { Clock quartz.Clock refreshCtx context.Context diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go index 10e171c12e017..210e04f9edf76 100644 --- a/enterprise/wsproxy/keycache_test.go +++ b/enterprise/wsproxy/keycache_test.go @@ -473,7 +473,7 @@ type fakeFetcher struct { called int } -func (f *fakeFetcher) Fetch(ctx context.Context) ([]codersdk.CryptoKey, error) { +func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) { f.called++ return f.keys, nil } From ee77d8ebb81fe770a99208f9919eae1e76648038 Mon Sep 17 00:00:00 2001 From: Jon Ayers Date: Fri, 4 Oct 2024 15:58:53 +0000 Subject: [PATCH 13/13] Remove redundant CryptoKey cache field The `latest` field in `CryptoKeyCache` was redundant and has been removed to simplify the code. Instead, the latest crypto key is now directly accessed from the `keys` map using `latestSequence`. This change reduces unnecessary state management and potential for data inconsistency. --- enterprise/wsproxy/keycache.go | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go index 5fdd9c356fa85..a877b9757d250 100644 --- a/enterprise/wsproxy/keycache.go +++ b/enterprise/wsproxy/keycache.go @@ -34,7 +34,6 @@ type CryptoKeyCache struct { mu sync.Mutex keys map[int32]codersdk.CryptoKey - latest codersdk.CryptoKey lastFetch time.Time refresher *quartz.Timer fetching bool @@ -57,12 +56,12 @@ func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opt cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx) cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh) - m, latest, err := cache.cryptoKeys(ctx) + keys, err := cache.cryptoKeys(ctx) if err != nil { cache.refreshCancel() return nil, xerrors.Errorf("initial fetch: %w", err) } - cache.keys, cache.latest = m, latest + cache.keys = keys return cache, nil } @@ -100,7 +99,7 @@ func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersd k.fetching = true k.mu.Unlock() - keys, latest, err := k.cryptoKeys(ctx) + keys, err := k.cryptoKeys(ctx) if err != nil { return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err) } @@ -108,7 +107,7 @@ func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersd k.mu.Lock() k.lastFetch = k.Clock.Now() k.refresher.Reset(refreshInterval) - k.keys, k.latest = keys, latest + k.keys = keys k.fetching = false k.cond.Broadcast() @@ -122,7 +121,7 @@ func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersd func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) { if sequence == latestSequence { - return k.latest, k.latest.CanSign(k.Clock.Now()) + return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.Clock.Now()) } key, ok := k.keys[sequence] @@ -171,7 +170,7 @@ func (k *CryptoKeyCache) refresh() { k.fetching = true k.mu.Unlock() - keys, latest, err := k.cryptoKeys(k.refreshCtx) + keys, err := k.cryptoKeys(k.refreshCtx) if err != nil { k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err)) return @@ -182,32 +181,32 @@ func (k *CryptoKeyCache) refresh() { k.lastFetch = k.Clock.Now() k.refresher.Reset(refreshInterval) - k.keys, k.latest = keys, latest + k.keys = keys k.fetching = false k.cond.Broadcast() } // cryptoKeys queries the control plane for the crypto keys. // Outside of initialization, this should only be called by fetch. -func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey, error) { +func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) { keys, err := k.fetcher.Fetch(ctx) if err != nil { - return nil, codersdk.CryptoKey{}, xerrors.Errorf("crypto keys: %w", err) + return nil, xerrors.Errorf("crypto keys: %w", err) } - cache, latest := toKeyMap(keys, k.Clock.Now()) - return cache, latest, nil + cache := toKeyMap(keys, k.Clock.Now()) + return cache, nil } -func toKeyMap(keys []codersdk.CryptoKey, now time.Time) (map[int32]codersdk.CryptoKey, codersdk.CryptoKey) { +func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey { m := make(map[int32]codersdk.CryptoKey) var latest codersdk.CryptoKey for _, key := range keys { m[key.Sequence] = key if key.Sequence > latest.Sequence && key.CanSign(now) { - latest = key + m[latestSequence] = key } } - return m, latest + return m } func (k *CryptoKeyCache) Close() {