8000 fix: avoid writing messages after close and improve handshake by FrauElster · Pull Request #476 · coder/websocket · GitHub
[go: up one dir, main page]

Skip to content

fix: avoid writing messages after close and improve handshake #476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 4, 2024
Merged
Prev Previous commit
Next Next commit
fix: echo read error after close received
  • Loading branch information
mafredri committed Sep 17, 2024
commit c3613fc66341cffa20fc9eacb3017bec47a50f59
4 changes: 4 additions & 0 deletions close.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ func (c *Conn) waitCloseHandshake() error {
}
defer c.readMu.unlock()

if c.readCloseErr != nil {
return c.readCloseErr
}

for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type Conn struct {
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
readCloseErr error

// Write state.
msgWriter *msgWriter
Expand All @@ -70,6 +71,7 @@ type Conn struct {
writeHeader header
closeSent bool

// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
Expand Down
19 changes: 18 additions & 1 deletion read.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,15 @@ func (c *Conn) readRSV1Illegal(h header) bool {
}

func (c *Conn) readLoop(ctx context.Context) (header, error) {
if c.readCloseErr != nil {
select {
case <-c.closed:
return header{}, net.ErrClosed
default:
}
return header{}, c.readCloseErr
}

for {
h, err := c.readFrameHeader(ctx)
if err != nil {
Expand Down Expand Up @@ -324,8 +333,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
return err
}

if c.readCloseErr == nil {
c.readCloseErr = ce
}

err = fmt.Errorf("received close frame: %w", ce)
c.writeClose(ce.Code, ce.Reason)
if err2 := c.writeClose(ce.Code, ce.Reason); errors.Is(err2, errCloseSent) {
// The close handshake has already been initiated, connection
// close should be handled elsewhere.
return err
}
c.readMu.unlock()
c.close()
return err
Expand Down
4 changes: 3 additions & 1 deletion write.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
return nil
}

var errCloseSent = errors.New("close sent")

// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
Expand All @@ -255,7 +257,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
return 0, net.ErrClosed
default:
}
return 0, errors.New("close sent")
return 0, errCloseSent
}

select {
Expand Down
Loading
0