diff --git a/enterprise/wsproxy/keycache.go b/enterprise/wsproxy/keycache.go
new file mode 100644
index 0000000000000..a877b9757d250
--- /dev/null
+++ b/enterprise/wsproxy/keycache.go
@@ -0,0 +1,265 @@
+package wsproxy
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	"golang.org/x/xerrors"
+
+	"cdr.dev/slog"
+
+	"github.com/coder/coder/v2/coderd/cryptokeys"
+	"github.com/coder/coder/v2/codersdk"
+	"github.com/coder/quartz"
+)
+
+const (
+	// latestSequence is a special sequence number that represents the latest key.
+	latestSequence = -1
+	// refreshInterval is the interval at which the key cache will refresh.
+	refreshInterval = time.Minute * 10
+)
+
+// Fetcher retrieves the current set of crypto keys.
+type Fetcher interface {
+	Fetch(ctx context.Context) ([]codersdk.CryptoKey, error)
+}
+
+// CryptoKeyCache caches the crypto keys returned by a Fetcher. Keys are
+// refetched on a cache miss and by a timer that fires every refreshInterval.
+// The fields below mu are guarded by it; cond is broadcast whenever an
+// in-flight fetch finishes or the cache is closed.
+type CryptoKeyCache struct {
+	Clock         quartz.Clock
+	refreshCtx    context.Context
+	refreshCancel context.CancelFunc
+	fetcher       Fetcher
+	logger        slog.Logger
+
+	mu        sync.Mutex
+	keys      map[int32]codersdk.CryptoKey
+	lastFetch time.Time
+	refresher *quartz.Timer
+	fetching  bool
+	closed    bool
+	cond      *sync.Cond
+}
+
+// NewCryptoKeyCache builds a cache, applies opts, arms the refresh timer and
+// performs an initial fetch. It returns an error if that first fetch fails.
+func NewCryptoKeyCache(ctx context.Context, log slog.Logger, client Fetcher, opts ...func(*CryptoKeyCache)) (*CryptoKeyCache, error) {
+	cache := &CryptoKeyCache{
+		Clock:   quartz.NewReal(),
+		logger:  log,
+		fetcher: client,
+	}
+
+	for _, opt := range opts {
+		opt(cache)
+	}
+
+	cache.cond = sync.NewCond(&cache.mu)
+	cache.refreshCtx, cache.refreshCancel = context.WithCancel(ctx)
+	cache.refresher = cache.Clock.AfterFunc(refreshInterval, cache.refresh)
+
+	keys, err := cache.cryptoKeys(ctx)
+	if err != nil {
+		// Stop the timer too so it cannot fire on a cache nobody holds.
+		cache.refresher.Stop()
+		cache.refreshCancel()
+		return nil, xerrors.Errorf("initial fetch: %w", err)
+	}
+	cache.keys = keys
+
+	return cache, nil
+}
+
+// Signing returns the latest key that can currently be used for signing.
+func (k *CryptoKeyCache) Signing(ctx context.Context) (codersdk.CryptoKey, error) {
+	return k.cryptoKey(ctx, latestSequence)
+}
+
+// Verifying returns the key with the given sequence if it can still verify.
+func (k *CryptoKeyCache) Verifying(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
+	return k.cryptoKey(ctx, sequence)
+}
+
+// cryptoKey looks up sequence in the cache, fetching from the control plane on
+// a miss. Only one goroutine fetches at a time; others wait on cond.
+func (k *CryptoKeyCache) cryptoKey(ctx context.Context, sequence int32) (codersdk.CryptoKey, error) {
+	k.mu.Lock()
+	defer k.mu.Unlock()
+
+	if k.closed {
+		return codersdk.CryptoKey{}, cryptokeys.ErrClosed
+	}
+
+	// Redo the lookup after every wake-up: the fetch that just finished may
+	// have delivered the key we are waiting for.
+	key, ok := k.key(sequence)
+	for !ok && k.fetching && !k.closed {
+		k.cond.Wait()
+		key, ok = k.key(sequence)
+	}
+
+	if k.closed {
+		return codersdk.CryptoKey{}, cryptokeys.ErrClosed
+	}
+
+	if ok {
+		return checkKey(key, sequence, k.Clock.Now())
+	}
+
+	k.fetching = true
+	k.mu.Unlock()
+	keys, err := k.cryptoKeys(ctx)
+	// Retake the lock before touching state or returning: the deferred Unlock
+	// needs it held, and waiters must be released even when the fetch failed.
+	k.mu.Lock()
+	k.fetching = false
+	k.cond.Broadcast()
+	if err != nil {
+		return codersdk.CryptoKey{}, xerrors.Errorf("get keys: %w", err)
+	}
+	if k.closed {
+		// Close ran during the fetch; don't re-arm the stopped timer.
+		return codersdk.CryptoKey{}, cryptokeys.ErrClosed
+	}
+
+	k.lastFetch = k.Clock.Now()
+	k.refresher.Reset(refreshInterval)
+	k.keys = keys
+
+	key, ok = k.key(sequence)
+	if !ok {
+		return codersdk.CryptoKey{}, cryptokeys.ErrKeyNotFound
+	}
+
+	return checkKey(key, sequence, k.Clock.Now())
+}
+
+// key returns the cached key for sequence and whether it is usable as a cache
+// hit. For latestSequence a hit requires that the key can sign right now. The
+// caller must hold mu.
+func (k *CryptoKeyCache) key(sequence int32) (codersdk.CryptoKey, bool) {
+	if sequence == latestSequence {
+		return k.keys[latestSequence], k.keys[latestSequence].CanSign(k.Clock.Now())
+	}
+
+	key, ok := k.keys[sequence]
+	return key, ok
+}
+
+// checkKey returns key if it is valid at now: able to sign when sequence is
+// latestSequence, able to verify otherwise. It returns ErrKeyInvalid if not.
+func checkKey(key codersdk.CryptoKey, sequence int32, now time.Time) (codersdk.CryptoKey, error) {
+	if sequence == latestSequence {
+		if !key.CanSign(now) {
+			return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
+		}
+		return key, nil
+	}
+
+	if !key.CanVerify(now) {
+		return codersdk.CryptoKey{}, cryptokeys.ErrKeyInvalid
+	}
+
+	return key, nil
+}
+
+// refresh fetches the keys from the control plane and updates the cache.
+func (k *CryptoKeyCache) refresh() {
+	now := k.Clock.Now("CryptoKeyCache", "refresh")
+	k.mu.Lock()
+
+	if k.closed {
+		k.mu.Unlock()
+		return
+	}
+
+	// If something's already fetching, we don't need to fetch again. Re-arm
+	// the timer here because the in-flight fetch only does so on success.
+	if k.fetching {
+		k.refresher.Reset(refreshInterval)
+		k.mu.Unlock()
+		return
+	}
+
+	// There's a window we must account for where the timer fires while a fetch
+	// is ongoing but prior to the timer getting reset. In this case we want to
+	// avoid double fetching.
+	if now.Sub(k.lastFetch) < refreshInterval {
+		k.mu.Unlock()
+		return
+	}
+
+	k.fetching = true
+
+	k.mu.Unlock()
+	keys, err := k.cryptoKeys(k.refreshCtx)
+
+	k.mu.Lock()
+	defer k.mu.Unlock()
+
+	// Clear the flag and wake waiters whether or not the fetch succeeded,
+	// otherwise a failure would leave them blocked forever.
+	k.fetching = false
+	k.cond.Broadcast()
+	if k.closed {
+		return
+	}
+	// Re-arm even after a failure so a single error doesn't stop refreshing.
+	k.refresher.Reset(refreshInterval)
+	if err != nil {
+		k.logger.Error(k.refreshCtx, "fetch crypto keys", slog.Error(err))
+		return
+	}
+
+	k.lastFetch = k.Clock.Now()
+	k.keys = keys
+}
+
+// cryptoKeys queries the control plane for the crypto keys and indexes them by
+// sequence. It does not touch the cache; callers store the result under mu.
+func (k *CryptoKeyCache) cryptoKeys(ctx context.Context) (map[int32]codersdk.CryptoKey, error) {
+	keys, err := k.fetcher.Fetch(ctx)
+	if err != nil {
+		return nil, xerrors.Errorf("crypto keys: %w", err)
+	}
+	cache := toKeyMap(keys, k.Clock.Now())
+	return cache, nil
+}
+
+// toKeyMap indexes keys by sequence. The highest-sequence key that can sign at
+// now is additionally stored under latestSequence.
+func toKeyMap(keys []codersdk.CryptoKey, now time.Time) map[int32]codersdk.CryptoKey {
+	m := make(map[int32]codersdk.CryptoKey)
+	var latest codersdk.CryptoKey
+	for _, key := range keys {
+		m[key.Sequence] = key
+		if key.Sequence > latest.Sequence && key.CanSign(now) {
+			// Remember the winner so lower sequences later in the slice
+			// cannot replace it.
+			latest = key
+			m[latestSequence] = key
+		}
+	}
+	return m
+}
+
+// Close stops the refresh timer, cancels any in-flight refresh and wakes
+// waiters. After Close, lookups return ErrClosed. It is safe to call twice.
+func (k *CryptoKeyCache) Close() {
+	k.mu.Lock()
+	defer k.mu.Unlock()
+
+	if k.closed {
+		return
+	}
+
+	k.closed = true
+	k.refreshCancel()
+	k.refresher.Stop()
+	k.cond.Broadcast()
+}
diff --git a/enterprise/wsproxy/keycache_test.go b/enterprise/wsproxy/keycache_test.go
new file mode 100644
index 0000000000000..210e04f9edf76
--- /dev/null
+++ b/enterprise/wsproxy/keycache_test.go
@@ -0,0 +1,491 @@
+package wsproxy_test
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"cdr.dev/slog/sloggers/slogtest"
+
+	"github.com/coder/coder/v2/coderd/cryptokeys"
+	"github.com/coder/coder/v2/codersdk"
+	"github.com/coder/coder/v2/enterprise/wsproxy"
+	"github.com/coder/coder/v2/testutil"
+	"github.com/coder/quartz"
+)
+
+// TestCryptoKeyCache exercises signing and verifying lookups, the periodic
+// refresh timer and Close, using a mock clock and a fake fetcher.
+func TestCryptoKeyCache(t *testing.T) {
+	t.Parallel()
+
+	t.Run("Signing", func(t *testing.T) {
+		t.Parallel()
+
+		t.Run("HitsCache", func(t *testing.T) {
+			t.Parallel()
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			now := clock.Now().UTC()
+			expected := codersdk.CryptoKey{
+				Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+				Secret:   "key2",
+				Sequence: 2,
+				StartsAt: now,
+			}
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{expected},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			got, err := cache.Signing(ctx)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			require.Equal(t, 1, ff.called)
+		})
+
+		t.Run("MissesCache", func(t *testing.T) {
+			t.Parallel()
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			expected := codersdk.CryptoKey{
+				Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+				Secret:   "key1",
+				Sequence: 12,
+				StartsAt: clock.Now().UTC(),
+			}
+			ff.keys = []codersdk.CryptoKey{expected}
+
+			got, err := cache.Signing(ctx)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			// 1 on startup + missing cache.
+			require.Equal(t, 2, ff.called)
+
+			// Ensure the cache gets hit this time.
+			got, err = cache.Signing(ctx)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			// 1 on startup + missing cache.
+			require.Equal(t, 2, ff.called)
+		})
+
+		t.Run("IgnoresInvalid", func(t *testing.T) {
+			t.Parallel()
+
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+			now := clock.Now().UTC()
+			expected := codersdk.CryptoKey{
+				Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+				Secret:   "key1",
+				Sequence: 1,
+				StartsAt: clock.Now().UTC(),
+			}
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{
+					expected,
+					{
+						Feature:   codersdk.CryptoKeyFeatureWorkspaceApp,
+						Secret:    "key2",
+						Sequence:  2,
+						StartsAt:  now.Add(-time.Second),
+						DeletesAt: now,
+					},
+				},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			got, err := cache.Signing(ctx)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			require.Equal(t, 1, ff.called)
+		})
+
+		t.Run("KeyNotFound", func(t *testing.T) {
+			t.Parallel()
+
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			_, err = cache.Signing(ctx)
+			require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
+		})
+	})
+
+	t.Run("Verifying", func(t *testing.T) {
+		t.Parallel()
+
+		t.Run("HitsCache", func(t *testing.T) {
+			t.Parallel()
+
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			now := clock.Now().UTC()
+			expected := codersdk.CryptoKey{
+				Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+				Secret:   "key1",
+				Sequence: 12,
+				StartsAt: now,
+			}
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{
+					expected,
+					{
+						Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+						Secret:   "key2",
+						Sequence: 13,
+						StartsAt: now,
+					},
+				},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			got, err := cache.Verifying(ctx, expected.Sequence)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			require.Equal(t, 1, ff.called)
+		})
+
+		t.Run("MissesCache", func(t *testing.T) {
+			t.Parallel()
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			expected := codersdk.CryptoKey{
+				Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+				Secret:   "key1",
+				Sequence: 12,
+				StartsAt: clock.Now().UTC(),
+			}
+			ff.keys = []codersdk.CryptoKey{expected}
+
+			got, err := cache.Verifying(ctx, expected.Sequence)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			require.Equal(t, 2, ff.called)
+
+			// Ensure the cache gets hit this time.
+			got, err = cache.Verifying(ctx, expected.Sequence)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			require.Equal(t, 2, ff.called)
+		})
+
+		t.Run("AllowsBeforeStartsAt", func(t *testing.T) {
+			t.Parallel()
+
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			now := clock.Now().UTC()
+			expected := codersdk.CryptoKey{
+				Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+				Secret:   "key1",
+				Sequence: 12,
+				StartsAt: now.Add(-time.Second),
+			}
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{
+					expected,
+				},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			got, err := cache.Verifying(ctx, expected.Sequence)
+			require.NoError(t, err)
+			require.Equal(t, expected, got)
+			require.Equal(t, 1, ff.called)
+		})
+
+		t.Run("KeyInvalid", func(t *testing.T) {
+			t.Parallel()
+
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			now := clock.Now().UTC()
+			expected := codersdk.CryptoKey{
+				Feature:   codersdk.CryptoKeyFeatureWorkspaceApp,
+				Secret:    "key1",
+				Sequence:  12,
+				StartsAt:  now.Add(-time.Second),
+				DeletesAt: now,
+			}
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{
+					expected,
+				},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			_, err = cache.Verifying(ctx, expected.Sequence)
+			require.ErrorIs(t, err, cryptokeys.ErrKeyInvalid)
+			require.Equal(t, 1, ff.called)
+		})
+
+		t.Run("KeyNotFound", func(t *testing.T) {
+			t.Parallel()
+
+			var (
+				ctx    = testutil.Context(t, testutil.WaitShort)
+				logger = slogtest.Make(t, nil)
+				clock  = quartz.NewMock(t)
+			)
+
+			ff := &fakeFetcher{
+				keys: []codersdk.CryptoKey{},
+			}
+
+			cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+			require.NoError(t, err)
+
+			_, err = cache.Verifying(ctx, 1)
+			require.ErrorIs(t, err, cryptokeys.ErrKeyNotFound)
+		})
+	})
+
+	t.Run("CacheRefreshes", func(t *testing.T) {
+		t.Parallel()
+
+		var (
+			ctx    = testutil.Context(t, testutil.WaitShort)
+			logger = slogtest.Make(t, nil)
+			clock  = quartz.NewMock(t)
+		)
+
+		now := clock.Now().UTC()
+		expected := codersdk.CryptoKey{
+			Feature:   codersdk.CryptoKeyFeatureWorkspaceApp,
+			Secret:    "key1",
+			Sequence:  12,
+			StartsAt:  now,
+			DeletesAt: now.Add(time.Minute * 10),
+		}
+		ff := &fakeFetcher{
+			keys: []codersdk.CryptoKey{
+				expected,
+			},
+		}
+
+		cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+		require.NoError(t, err)
+
+		got, err := cache.Signing(ctx)
+		require.NoError(t, err)
+		require.Equal(t, expected, got)
+		require.Equal(t, 1, ff.called)
+
+		newKey := codersdk.CryptoKey{
+			Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+			Secret:   "key2",
+			Sequence: 13,
+			StartsAt: now,
+		}
+		ff.keys = []codersdk.CryptoKey{newKey}
+
+		// The ticker should fire and cause a request to coderd.
+		dur, advance := clock.AdvanceNext()
+		advance.MustWait(ctx)
+		require.Equal(t, 2, ff.called)
+		require.Equal(t, time.Minute*10, dur)
+
+		// Assert hits cache.
+		got, err = cache.Signing(ctx)
+		require.NoError(t, err)
+		require.Equal(t, newKey, got)
+		require.Equal(t, 2, ff.called)
+
+		// We check again to ensure the timer has been reset.
+		dur, advance = clock.AdvanceNext()
+		advance.MustWait(ctx)
+		require.Equal(t, 3, ff.called)
+		require.Equal(t, time.Minute*10, dur)
+	})
+
+	// This test ensures that if the refresh timer races with an inflight request
+	// and loses that it doesn't cause a redundant fetch.
+
+	t.Run("RefreshNoDoubleFetch", func(t *testing.T) {
+		t.Parallel()
+
+		var (
+			ctx    = testutil.Context(t, testutil.WaitShort)
+			logger = slogtest.Make(t, nil)
+			clock  = quartz.NewMock(t)
+		)
+
+		now := clock.Now().UTC()
+		expected := codersdk.CryptoKey{
+			Feature:   codersdk.CryptoKeyFeatureWorkspaceApp,
+			Secret:    "key1",
+			Sequence:  12,
+			StartsAt:  now,
+			DeletesAt: now.Add(time.Minute * 10),
+		}
+		ff := &fakeFetcher{
+			keys: []codersdk.CryptoKey{
+				expected,
+			},
+		}
+
+		// Create a trap that blocks when the refresh timer fires.
+		trap := clock.Trap().Now("refresh")
+		cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+		require.NoError(t, err)
+
+		_, wait := clock.AdvanceNext()
+		trapped := trap.MustWait(ctx)
+
+		newKey := codersdk.CryptoKey{
+			Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+			Secret:   "key2",
+			Sequence: 13,
+			StartsAt: now,
+		}
+		ff.keys = []codersdk.CryptoKey{newKey}
+
+		_, err = cache.Verifying(ctx, newKey.Sequence)
+		require.NoError(t, err)
+		require.Equal(t, 2, ff.called)
+
+		trapped.Release()
+		wait.MustWait(ctx)
+		require.Equal(t, 2, ff.called)
+		trap.Close()
+
+		// The next timer should fire in 10 minutes.
+		dur, wait := clock.AdvanceNext()
+		wait.MustWait(ctx)
+		require.Equal(t, time.Minute*10, dur)
+		require.Equal(t, 3, ff.called)
+	})
+
+	t.Run("Closed", func(t *testing.T) {
+		t.Parallel()
+
+		var (
+			ctx    = testutil.Context(t, testutil.WaitShort)
+			logger = slogtest.Make(t, nil)
+			clock  = quartz.NewMock(t)
+		)
+
+		now := clock.Now()
+		expected := codersdk.CryptoKey{
+			Feature:  codersdk.CryptoKeyFeatureWorkspaceApp,
+			Secret:   "key1",
+			Sequence: 12,
+			StartsAt: now,
+		}
+		ff := &fakeFetcher{
+			keys: []codersdk.CryptoKey{
+				expected,
+			},
+		}
+
+		cache, err := wsproxy.NewCryptoKeyCache(ctx, logger, ff, withClock(clock))
+		require.NoError(t, err)
+
+		got, err := cache.Signing(ctx)
+		require.NoError(t, err)
+		require.Equal(t, expected, got)
+		require.Equal(t, 1, ff.called)
+
+		got, err = cache.Verifying(ctx, expected.Sequence)
+		require.NoError(t, err)
+		require.Equal(t, expected, got)
+		require.Equal(t, 1, ff.called)
+
+		cache.Close()
+
+		_, err = cache.Signing(ctx)
+		require.ErrorIs(t, err, cryptokeys.ErrClosed)
+
+		_, err = cache.Verifying(ctx, expected.Sequence)
+		require.ErrorIs(t, err, cryptokeys.ErrClosed)
+	})
+}
+
+// fakeFetcher is a wsproxy.Fetcher stub that returns a fixed key set and
+// counts how many times Fetch was invoked.
+type fakeFetcher struct {
+	keys   []codersdk.CryptoKey
+	called int
+}
+
+// Fetch records the call and returns the configured keys with a nil error.
+func (f *fakeFetcher) Fetch(_ context.Context) ([]codersdk.CryptoKey, error) {
+	f.called++
+	return f.keys, nil
+}
+
+// withClock returns a cache option that replaces the cache's Clock.
+func withClock(clock quartz.Clock) func(*wsproxy.CryptoKeyCache) {
+	return func(cache *wsproxy.CryptoKeyCache) {
+		cache.Clock = clock
+	}
+}