@@ -2,122 +2,260 @@ package workspacetraffic
2
2
3
3
import (
4
4
"context"
5
+ "encoding/json"
6
+ "errors"
5
7
"io"
6
8
"sync"
9
+ "time"
7
10
8
11
"github.com/coder/coder/v2/codersdk"
9
12
10
13
"github.com/google/uuid"
11
- "github.com/hashicorp/go-multierror"
12
14
gossh "golang.org/x/crypto/ssh"
13
15
"golang.org/x/xerrors"
14
16
)
15
17
16
- func connectPTY (ctx context.Context , client * codersdk.Client , agentID , reconnect uuid.UUID ) (* countReadWriteCloser , error ) {
18
+ const (
19
+ // Set a timeout for graceful close of the connection.
20
+ connCloseTimeout = 30 * time .Second
21
+ // Set a timeout for waiting for the connection to close.
22
+ waitCloseTimeout = connCloseTimeout + 5 * time .Second
23
+
24
+ // In theory, we can send larger payloads to push bandwidth, but we need to
25
+ // be careful not to send too much data at once or the server will close the
26
+ // connection. We see this more readily as our JSON payloads approach 28KB.
27
+ //
28
+ // failed to write frame: WebSocket closed: received close frame: status = StatusMessageTooBig and reason = "read limited at 32769 bytes"
29
+ //
30
+ // Since we can't control fragmentation/buffer sizes, we keep it simple and
31
+ // match the conservative payload size used by agent/reconnectingpty (1024).
32
8000
+ rptyJSONMaxDataSize = 1024
33
+ )
34
+
35
+ func connectRPTY (ctx context.Context , client * codersdk.Client , agentID , reconnect uuid.UUID , cmd string ) (* countReadWriteCloser , error ) {
36
+ width , height := 80 , 25
17
37
conn , err := client .WorkspaceAgentReconnectingPTY (ctx , codersdk.WorkspaceAgentReconnectingPTYOpts {
18
38
AgentID : agentID ,
19
39
Reconnect : reconnect ,
20
- Height : 25 ,
21
- Width : 80 ,
22
- Command : "sh" ,
40
+ Width : uint16 ( width ) ,
41
+ Height : uint16 ( height ) ,
42
+ Command : cmd ,
23
43
})
24
44
if err != nil {
25
45
return nil , xerrors .Errorf ("connect pty: %w" , err )
26
46
}
27
47
28
48
// Wrap the conn in a countReadWriteCloser so we can monitor bytes sent/rcvd.
29
- crw := countReadWriteCloser {ctx : ctx , rwc : conn }
49
+ crw := countReadWriteCloser {rwc : newPTYConn ( conn ) }
30
50
return & crw , nil
31
51
}
32
52
33
- func connectSSH (ctx context.Context , client * codersdk.Client , agentID uuid.UUID ) (* countReadWriteCloser , error ) {
53
+ type rptyConn struct {
54
+ conn io.ReadWriteCloser
55
+ wenc * json.Encoder
56
+
57
+ mu sync.Mutex
58
+ closed bool
59
+ readErr chan error
60
+ readOnce sync.Once
61
+ }
62
+
63
+ func newPTYConn (conn io.ReadWriteCloser ) * rptyConn {
64
+ rc := & rptyConn {
65
+ conn : conn ,
66
+ wenc : json .NewEncoder (conn ),
67
+ readErr : make (chan error , 1 ),
68
+ }
69
+ return rc
70
+ }
71
+
72
+ func (c * rptyConn ) Read (p []byte ) (int , error ) {
73
+ n , err := c .conn .Read (p )
74
+ if err != nil {
75
+ c .readOnce .Do (func () {
76
+ c .readErr <- err
77
+ close (c .readErr )
78
+ })
79
+ return n , err
80
+ }
81
+ return n , nil
82
+ }
83
+
84
+ func (c * rptyConn ) Write (p []byte ) (int , error ) {
85
+ c .mu .Lock ()
86
+ defer c .mu .Unlock ()
87
+
88
+ // Early exit in case we're closing, this is to let call write Ctrl+C
89
+ // without a flood of other writes.
90
+ if c .closed {
91
+ return 0 , io .EOF
92
+ }
93
+
94
+ return c .writeNoLock (p )
95
+ }
96
+
97
+ func (c * rptyConn ) writeNoLock (p []byte ) (n int , err error ) {
98
+ // If we try to send more than the max payload size, the server will close the connection.
99
+ for len (p ) > 0 {
100
+ pp := p
101
+ if len (pp ) > rptyJSONMaxDataSize {
102
+ pp = p [:rptyJSONMaxDataSize ]
103
+ }
104
+ p = p [len (pp ):]
105
+ req := codersdk.ReconnectingPTYRequest {Data : string (pp )}
106
+ if err := c .wenc .Encode (req ); err != nil {
107
+ return n , xerrors .Errorf ("encode pty request: %w" , err )
108
+ }
109
+ n += len (pp )
110
+ }
111
+ return n , nil
112
+ }
113
+
114
+ func (c * rptyConn ) Close () (err error ) {
115
+ c .mu .Lock ()
116
+ if c .closed {
117
+ c .mu .Unlock ()
118
+ return nil
119
+ }
120
+ c .closed = true
121
+ c .mu .Unlock ()
122
+
123
+ defer c .conn .Close ()
124
+
125
+ // Send Ctrl+C to interrupt the command.
126
+ _ , err = c .writeNoLock ([]byte ("\u0003 " ))
127
+ if err != nil {
128
+ return xerrors .Errorf ("write ctrl+c: %w" , err )
129
+ }
130
+ select {
131
+ case <- time .After (connCloseTimeout ):
132
+ return xerrors .Errorf ("timeout waiting for read to finish" )
133
+ case err = <- c .readErr :
134
+ if errors .Is (err , io .EOF ) {
135
+ return nil
136
+ }
137
+ return err
138
+ }
139
+ }
140
+
141
+ //nolint:revive // Ignore requestPTY control flag.
142
+ func connectSSH (ctx context.Context , client * codersdk.Client , agentID uuid.UUID , cmd string , requestPTY bool ) (rwc * countReadWriteCloser , err error ) {
143
+ var closers []func () error
144
+ defer func () {
145
+ if err != nil {
146
+ for _ , c := range closers {
147
+ if err2 := c (); err2 != nil {
148
+ err = errors .Join (err , err2 )
149
+ }
150
+ }
151
+ }
152
+ }()
153
+
34
154
agentConn , err := client .DialWorkspaceAgent (ctx , agentID , & codersdk.DialWorkspaceAgentOptions {})
35
155
if err != nil {
36
156
return nil , xerrors .Errorf ("dial workspace agent: %w" , err )
37
157
}
38
- agentConn .AwaitReachable (ctx )
158
+ closers = append (closers , agentConn .Close )
159
+
39
160
sshClient , err := agentConn .SSHClient (ctx )
40
161
if err != nil {
41
162
return nil , xerrors .Errorf ("get ssh client: %w" , err )
42
163
}
164
+ closers = append (closers , sshClient .Close )
165
+
43
166
sshSession , err := sshClient .NewSession ()
44
167
if err != nil {
45
- _ = agentConn .Close ()
46
168
return nil , xerrors .Errorf ("new ssh session: %w" , err )
47
169
}
48
- wrappedConn := & wrappedSSHConn {ctx : ctx }
170
+ closers = append (closers , sshSession .Close )
171
+
172
+ wrappedConn := & wrappedSSHConn {}
173
+
49
174
// Do some plumbing to hook up the wrappedConn
50
175
pr1 , pw1 := io .Pipe ()
176
+ closers = append (closers , pr1 .Close , pw1 .Close )
51
177
wrappedConn .stdout = pr1
52
178
sshSession .Stdout = pw1
179
+
53
180
pr2 , pw2 := io .Pipe ()
181
+ closers = append (closers , pr2 .Close , pw2 .Close )
54
182
sshSession .Stdin = pr2
55
183
wrappedConn .stdin = pw2
56
- err = sshSession .RequestPty ("xterm" , 25 , 80 , gossh.TerminalModes {})
57
- if err != nil {
58
- _ = pr1 .Close ()
59
- _ = pr2 .Close ()
60
- _ = pw1 .Close ()
61
- _ = pw2 .Close ()
62
- _ = sshSession .Close ()
63
- _ = agentConn .Close ()
64
- return nil , xerrors .Errorf ("request pty: %w" , err )
184
+
185
+ if requestPTY {
186
+ err = sshSession .RequestPty ("xterm" , 25 , 80 , gossh.TerminalModes {})
187
+ if err != nil {
188
+ return nil , xerrors .Errorf ("request pty: %w" , err )
189
+ }
65
190
}
66
- err = sshSession .Shell ( )
191
+ err = sshSession .Start ( cmd )
67
192
if err != nil {
68
- _ = sshSession .Close ()
69
- _ = agentConn .Close ()
70
193
return nil , xerrors .Errorf ("shell: %w" , err )
71
194
}
195
+ waitErr := make (chan error , 1 )
196
+ go func () {
197
+ waitErr <- sshSession .Wait ()
198
+ }()
72
199
73
200
closeFn := func () error {
74
- var merr error
75
- if err := sshSession .Close (); err != nil {
76
- merr = multierror .Append (merr , err )
201
+ // Start by closing stdin so we stop writing to the ssh session.
202
+ merr := pw2 .Close ()
203
+ if err := sshSession .Signal (gossh .SIGHUP ); err != nil {
204
+ merr = errors .Join (merr , err )
77
205
}
78
- if err := agentConn .Close (); err != nil {
79
- merr = multierror .Append (merr , err )
206
+ select {
207
+ case <- time .After (connCloseTimeout ):
208
+ merr = errors .Join (merr , xerrors .Errorf ("timeout waiting for ssh session to close" ))
209
+ case err := <- waitErr :
210
+ if err != nil {
211
+ var exitErr * gossh.ExitError
212
+ if xerrors .As (err , & exitErr ) {
213
+ // The exit status is 255 when the command is
214
+ // interrupted by a signal. This is expected.
215
+ if exitErr .ExitStatus () != 255 {
216
+ merr = errors .Join (merr , xerrors .Errorf ("ssh session exited with unexpected status: %d" , int32 (exitErr .ExitStatus ())))
217
+ }
218
+ } else {
219
+ merr = errors .Join (merr , err )
220
+ }
221
+ }
222
+ }
223
+ for _ , c := range closers {
224
+ if err := c (); err != nil {
225
+ if ! errors .Is (err , io .EOF ) {
226
+ merr = errors .Join (merr , err )
227
+ }
228
+ }
80
229
}
81
230
return merr
82
231
}
83
232
wrappedConn .close = closeFn
84
233
85
- crw := & countReadWriteCloser {ctx : ctx , rwc : wrappedConn }
234
+ crw := & countReadWriteCloser {rwc : wrappedConn }
235
+
86
236
return crw , nil
87
237
}
88
238
89
239
// wrappedSSHConn wraps an ssh.Session to implement io.ReadWriteCloser.
90
240
type wrappedSSHConn struct {
91
- ctx context.Context
92
241
stdout io.Reader
93
- stdin io.Writer
242
+ stdin io.WriteCloser
94
243
closeOnce sync.Once
95
244
closeErr error
96
245
close func () error
97
246
}
98
247
99
248
func (w * wrappedSSHConn ) Close () error {
100
249
w .closeOnce .Do (func () {
101
- _ , _ = w .stdin .Write ([]byte ("exit\n " ))
102
250
w .closeErr = w .close ()
103
251
})
104
252
return w .closeErr
105
253
}
106
254
107
255
func (w * wrappedSSHConn ) Read (p []byte ) (n int , err error ) {
108
- select {
109
- case <- w .ctx .Done ():
110
- return 0 , xerrors .Errorf ("read: %w" , w .ctx .Err ())
111
- default :
112
- return w .stdout .Read (p )
113
- }
256
+ return w .stdout .Read (p )
114
257
}
115
258
116
259
func (w * wrappedSSHConn ) Write (p []byte ) (n int , err error ) {
117
- select {
118
- case <- w .ctx .Done ():
119
- return 0 , xerrors .Errorf ("write: %w" , w .ctx .Err ())
120
- default :
121
- return w .stdin .Write (p )
122
- }
260
+ return w .stdin .Write (p )
123
261
}
0 commit comments