diff --git a/archive/fs/tar.go b/archive/fs/tar.go new file mode 100644 index 0000000000000..ab4027d5445ee --- /dev/null +++ b/archive/fs/tar.go @@ -0,0 +1,17 @@ +package archivefs + +import ( + "archive/tar" + "io" + "io/fs" + + "github.com/spf13/afero" + "github.com/spf13/afero/tarfs" +) + +func FromTarReader(r io.Reader) fs.FS { + tr := tar.NewReader(r) + tfs := tarfs.New(tr) + rofs := afero.NewReadOnlyFs(tfs) + return afero.NewIOFS(rofs) +} diff --git a/coderd/files/cache.go b/coderd/files/cache.go new file mode 100644 index 0000000000000..b823680fa7245 --- /dev/null +++ b/coderd/files/cache.go @@ -0,0 +1,110 @@ +package files + +import ( + "bytes" + "context" + "io/fs" + "sync" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + archivefs "github.com/coder/coder/v2/archive/fs" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/util/lazy" +) + +// NewFromStore returns a file cache that will fetch files from the provided +// database. +func NewFromStore(store database.Store) Cache { + fetcher := func(ctx context.Context, fileID uuid.UUID) (fs.FS, error) { + file, err := store.GetFileByID(ctx, fileID) + if err != nil { + return nil, xerrors.Errorf("failed to read file from database: %w", err) + } + + content := bytes.NewBuffer(file.Data) + return archivefs.FromTarReader(content), nil + } + + return Cache{ + lock: sync.Mutex{}, + data: make(map[uuid.UUID]*cacheEntry), + fetcher: fetcher, + } +} + +// Cache persists the files for template versions, and is used by dynamic +// parameters to deduplicate the files in memory. When any number of users opens +// the workspace creation form for a given template version, it's files are +// loaded into memory exactly once. We hold those files until there are no +// longer any open connections, and then we remove the value from the map. +type Cache struct { + lock sync.Mutex + data map[uuid.UUID]*cacheEntry + fetcher +} + +type cacheEntry struct { + // refCount must only be accessed while the Cache lock is held. + refCount int + value *lazy.ValueWithError[fs.FS] +} + +type fetcher func(context.Context, uuid.UUID) (fs.FS, error) + +// Acquire will load the fs.FS for the given file. It guarantees that parallel +// calls for the same fileID will only result in one fetch, and that parallel +// calls for distinct fileIDs will fetch in parallel. +// +// Every call to Acquire must have a matching call to Release. +func (c *Cache) Acquire(ctx context.Context, fileID uuid.UUID) (fs.FS, error) { + // It's important that this `Load` call occurs outside of `prepare`, after the + // mutex has been released, or we would continue to hold the lock until the + // entire file has been fetched, which may be slow, and would prevent other + // files from being fetched in parallel. + return c.prepare(ctx, fileID).Load() +} + +func (c *Cache) prepare(ctx context.Context, fileID uuid.UUID) *lazy.ValueWithError[fs.FS] { + c.lock.Lock() + defer c.lock.Unlock() + + entry, ok := c.data[fileID] + if !ok { + value := lazy.NewWithError(func() (fs.FS, error) { + return c.fetcher(ctx, fileID) + }) + + entry = &cacheEntry{ + value: value, + refCount: 0, + } + c.data[fileID] = entry + } + + entry.refCount++ + return entry.value +} + +// Release decrements the reference count for the given fileID, and frees the +// backing data if there are no further references being held. +func (c *Cache) Release(fileID uuid.UUID) { + c.lock.Lock() + defer c.lock.Unlock() + + entry, ok := c.data[fileID] + if !ok { + // If we land here, it's almost certainly because a bug already happened, + // and we're freeing something that's already been freed, or we're calling + // this function with an incorrect ID. Should this function return an error? + return + } + + entry.refCount-- + if entry.refCount > 0 { + return + } + + delete(c.data, fileID) +} diff --git a/coderd/files/cache_internal_test.go b/coderd/files/cache_internal_test.go new file mode 100644 index 0000000000000..03603906b6ccd --- /dev/null +++ b/coderd/files/cache_internal_test.go @@ -0,0 +1,104 @@ +package files + +import ( + "context" + "io/fs" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/spf13/afero" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/coder/coder/v2/testutil" +) + +func TestConcurrency(t *testing.T) { + t.Parallel() + + emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs())) + var fetches atomic.Int64 + c := newTestCache(func(_ context.Context, _ uuid.UUID) (fs.FS, error) { + fetches.Add(1) + // Wait long enough before returning to make sure that all of the goroutines + // will be waiting in line, ensuring that no one duplicated a fetch. + time.Sleep(testutil.IntervalMedium) + return emptyFS, nil + }) + + batches := 1000 + groups := make([]*errgroup.Group, 0, batches) + for range batches { + groups = append(groups, new(errgroup.Group)) + } + + // Call Acquire with a unique ID per batch, many times per batch, with many + // batches all in parallel. This is pretty much the worst-case scenario: + // thousands of concurrent reads, with both warm and cold loads happening. + batchSize := 10 + for _, g := range groups { + id := uuid.New() + for range batchSize { + g.Go(func() error { + // We don't bother to Release these references because the Cache will be + // released at the end of the test anyway. + _, err := c.Acquire(t.Context(), id) + return err + }) + } + } + + for _, g := range groups { + require.NoError(t, g.Wait()) + } + require.Equal(t, int64(batches), fetches.Load()) +} + +func TestRelease(t *testing.T) { + t.Parallel() + + emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs())) + c := newTestCache(func(_ context.Context, _ uuid.UUID) (fs.FS, error) { + return emptyFS, nil + }) + + batches := 100 + ids := make([]uuid.UUID, 0, batches) + for range batches { + ids = append(ids, uuid.New()) + } + + // Acquire a bunch of references + batchSize := 10 + for _, id := range ids { + for range batchSize { + it, err := c.Acquire(t.Context(), id) + require.NoError(t, err) + require.Equal(t, emptyFS, it) + } + } + + // Make sure cache is fully loaded + require.Equal(t, len(c.data), batches) + + // Now release all of the references + for _, id := range ids { + for range batchSize { + c.Release(id) + } + } + + // ...and make sure that the cache has emptied itself. + require.Equal(t, len(c.data), 0) +} + +func newTestCache(fetcher func(context.Context, uuid.UUID) (fs.FS, error)) Cache { + return Cache{ + lock: sync.Mutex{}, + data: make(map[uuid.UUID]*cacheEntry), + fetcher: fetcher, + } +} diff --git a/coderd/util/lazy/valuewitherror.go b/coderd/util/lazy/valuewitherror.go new file mode 100644 index 0000000000000..acc9a370eea23 --- /dev/null +++ b/coderd/util/lazy/valuewitherror.go @@ -0,0 +1,25 @@ +package lazy + +type ValueWithError[T any] struct { + inner Value[result[T]] +} + +type result[T any] struct { + value T + err error +} + +// NewWithError allows you to provide a lazy initializer that can fail. +func NewWithError[T any](fn func() (T, error)) *ValueWithError[T] { + return &ValueWithError[T]{ + inner: Value[result[T]]{fn: func() result[T] { + value, err := fn() + return result[T]{value: value, err: err} + }}, + } +} + +func (v *ValueWithError[T]) Load() (T, error) { + result := v.inner.Load() + return result.value, result.err +} diff --git a/coderd/util/lazy/valuewitherror_test.go b/coderd/util/lazy/valuewitherror_test.go new file mode 100644 index 0000000000000..4949c57a6f2ac --- /dev/null +++ b/coderd/util/lazy/valuewitherror_test.go @@ -0,0 +1,52 @@ +package lazy_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/util/lazy" +) + +func TestLazyWithErrorOK(t *testing.T) { + t.Parallel() + + l := lazy.NewWithError(func() (int, error) { + return 1, nil + }) + + i, err := l.Load() + require.NoError(t, err) + require.Equal(t, 1, i) +} + +func TestLazyWithErrorErr(t *testing.T) { + t.Parallel() + + l := lazy.NewWithError(func() (int, error) { + return 0, xerrors.New("oh no! everything that could went horribly wrong!") + }) + + i, err := l.Load() + require.Error(t, err) + require.Equal(t, 0, i) +} + +func TestLazyWithErrorPointers(t *testing.T) { + t.Parallel() + + a := 1 + l := lazy.NewWithError(func() (*int, error) { + return &a, nil + }) + + b, err := l.Load() + require.NoError(t, err) + c, err := l.Load() + require.NoError(t, err) + + *b++ + *c++ + require.Equal(t, 3, a) +}