8000 Ensure connection is closed at all error points by nhooyr · Pull Request #193 · coder/websocket · GitHub
[go: up one dir, main page]

Skip to content

Ensure connection is closed at all error points #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
8000
Diff view
Next Next commit
Ensure connection is closed at all error points
Closes #191
  • Loading branch information
nhooyr committed Feb 20, 2020
commit 1200707bd313a46fda2fb828ff840dac4e12283f
26 changes: 12 additions & 14 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
defer c.readMu.unlock()

if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
err = errors.New("previous message not read to completion")
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
}

h, err := c.readLoop(ctx)
Expand Down Expand Up @@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
}

func (mr *msgReader) Read(p []byte) (n int, err error) {
defer func() {
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
err = io.EOF
}
if errors.Is(err, io.EOF) {
err = io.EOF
mr.putFlateReader()
return
}
errd.Wrap(&err, "failed to read")
}()

err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()

Expand All @@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.close(err)
}
return n, err
}

Expand Down
19 changes: 15 additions & 4 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,16 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
defer errd.Wrap(&err, "failed to write")

mw.writeMu.Lock()
defer mw.writeMu.Unlock()

defer func() {
err = fmt.Errorf("failed to write: %w", err)
if err != nil {
mw.c.close(err)
}
}()

if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
Expand Down Expand Up @@ -230,8 +235,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
}

// frame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
err := c.writeFrameMu.lock(ctx)
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
Expand All @@ -243,6 +248,12 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
case c.writeTimeout <- ctx:
}

defer func() {
if err != nil {
c.close(fmt.Errorf("failed to write frame: %w", err))
}
}()

c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
Expand Down
0