8000 Use a limit writer to capture actual tar size · coder/coder@0cd070c · GitHub
[go: up one dir, main page]

Skip to content

Commit 0cd070c

Browse files
committed
Use a limit writer to capture actual tar size
1 parent b34e9d8 commit 0cd070c

File tree

3 files changed

+189
-11
lines changed

3 files changed

+189
-11
lines changed

coderd/util/xio/limitwriter.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package xio
2+
3+
import (
4+
"io"
5+
6+
"golang.org/x/xerrors"
7+
)
8+
9+
var ErrLimitReached = xerrors.Errorf("i/o limit reached")
10+
11+
// LimitWriter will only write bytes to the underlying writer until the limit is reached.
12+
type LimitWriter struct {
13+
Limit int64
14+
N int64
15+
W io.Writer
16+
}
17+
18+
func NewLimitWriter(w io.Writer, n int64) *LimitWriter {
19+
// If anyone tries this, just make a 0 writer.
20+
if n < 0 {
21+
n = 0
22+
}
23+
return &LimitWriter{
24+
Limit: n,
25+
N: 0,
26+
W: w,
27+
}
28+
}
29+
30+
func (l *LimitWriter) Write(p []byte) (int, error) {
31+
if l.N >= l.Limit {
32+
return 0, ErrLimitReached
33+
}
34+
35+
// Write 0 bytes if the limit is to be exceeded.
36+
if int64(len(p)) > l.Limit-l.N {
37+
return 0, ErrLimitReached
38+
}
39+
40+
n, err := l.W.Write(p)
41+
l.N += int64(n)
42+
return n, err
43+
}

coderd/util/xio/limitwriter_test.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package xio_test
2+
3+
import (
4+
"bytes"
5+
cryptorand "crypto/rand"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/coderd/util/xio"
11+
)
12+
13+
func TestLimitWriter(t *testing.T) {
14+
type writeCase struct {
15+
N int
16+
ExpN int
17+
Err bool
18+
}
19+
20+
// testCases will do multiple writes to the same limit writer and check the output.
21+
testCases := []struct {
22+
Name string
23+
L int64
24+
Writes []writeCase
25+
N int
26+
ExpN int
27+
}{
28+
{
29+
Name: "Empty",
30+
L: 1000,
31+
Writes: []writeCase{
32+
// A few empty writes
33+
{N: 0, ExpN: 0}, {N: 0, ExpN: 0}, {N: 0, ExpN: 0},
34+
},
35+
},
36+
{
37+
Name: "NotFull",
38+
L: 1000,
39+
Writes: []writeCase{
40+
{N: 250, ExpN: 250},
41+
{N: 250, ExpN: 250},
42+
{N: 250, ExpN: 250},
43+
},
44+
},
45+
{
46+
Name: "Short",
47+
L: 1000,
48+
Writes: []writeCase{
49+
{N: 250, ExpN: 250},
50+
{N: 250, ExpN: 250},
51+
{N: 250, ExpN: 250},
52+
{N: 250, ExpN: 250},
53+
{N: 250, ExpN: 0, Err: true},
54+
},
55+
},
56+
{
57+
Name: "Exact",
58+
L: 1000,
59+
Writes: []writeCase{
60+
{
61+
N: 1000,
62+
ExpN: 1000,
63+
},
64+
{
65+
N: 1000,
66+
Err: true,
67+
},
68+
},
69+
},
70+
{
71+
Name: "Over",
72+
L: 1000,
73+
Writes: []writeCase{
74+
{
75+
N: 5000,
76+
ExpN: 0,
77+
Err: true,
78+
},
79+
{
80+
N: 5000,
81+
Err: true,
82+
},
83+
{
84+
N: 5000,
85+
Err: true,
86+
},
87+
},
88+
},
89+
{
90+
Name: "Strange",
91+
L: -1,
92+
Writes: []writeCase{
93+
{
94+
N: 5,
95+
ExpN: 0,
96+
Err: true,
97+
},
98+
{
99+
N: 0,
100+
ExpN: 0,
101+
Err: true,
102+
},
103+
},
104+
},
105+
}
106+
107+
for _, c := range testCases {
108+
t.Run(c.Name, func(t *testing.T) {
109+
buf := bytes.NewBuffer([]byte{})
110+
allBuff := bytes.NewBuffer([]byte{})
111+
w := xio.NewLimitWriter(buf, c.L)
112+
113+
for _, wc := range c.Writes {
114+
data := make([]byte, wc.N)
115+
116+
n, err := cryptorand.Read(data)
117+
require.NoError(t, err, "crand read")
118+
require.Equal(t, wc.N, n, "correct bytes read")
119+
max := data[:wc.ExpN]
120+
n, err = w.Write(data)
121+
if wc.Err {
122+
require.Error(t, err, "exp error")
123+
} else {
124+
require.NoError(t, err, "write")
125+
}
126+
127+
// Need to use this to compare across multiple writes.
128+
// Each write appends to the expected output.
129+
allBuff.Write(max)
130+
131+
require.Equal(t, wc.ExpN, n, "correct bytes written")
132+
require.Equal(t, allBuff.Bytes(), buf.Bytes(), "expected data")
133+
}
134+
})
135+
}
136+
}

provisionersdk/archive.go

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"strings"
99

1010
"golang.org/x/xerrors"
11+
12+
"github.com/coder/coder/coderd/util/xio"
1113
)
1214

1315
const (
@@ -32,8 +34,9 @@ func dirHasExt(dir string, ext string) (bool, error) {
3234

3335
// Tar archives a Terraform directory.
3436
func Tar(w io.Writer, directory string, limit int64) error {
37+
// The total bytes written must be under the limit.
38+
w = xio.NewLimitWriter(w, limit)
3539
tarWriter := tar.NewWriter(w)
36-
totalSize := int64(0)
3740

3841
const tfExt = ".tf"
3942
hasTf, err := dirHasExt(directory, tfExt)
@@ -54,7 +57,6 @@ func Tar(w io.Writer, directory string, limit int64) error {
5457
)
5558
}
5659

57-
fileTooBigError := xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
5860
err = filepath.Walk(directory, func(file string, fileInfo os.FileInfo, err error) error {
5961
if err != nil {
6062
return err
@@ -96,23 +98,20 @@ func Tar(w io.Writer, directory string, limit int64) error {
9698
if !fileInfo.Mode().IsRegular() {
9799
return nil
98100
}
99-
// Before we even open the file, check if it is going to exceed our limit.
100-
if fileInfo.Size()+totalSize > limit {
101-
return fileTooBigError
102-
}
101+
103102
data, err := os.Open(file)
104103
if err != nil {
105104
return err
106105
}
107106
defer data.Close()
108-
wrote, err := io.Copy(tarWriter, data)
107+
_, err = io.Copy(tarWriter, data)
109108
if err != nil {
109+
if xerrors.Is(err, xio.ErrLimitReached) {
110+
return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
111+
}
110112
return err
111113
}
112-
totalSize += wrote
113-
if limit != 0 && totalSize > limit {
114-
return fileTooBigError
115-
}
114+
116115
return data.Close()
117116
})
118117
if err != nil {

0 commit comments

Comments
 (0)
0