diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index fb83c3a9..00000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1 +0,0 @@ -github: nhooyr diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..fb0a4558 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,24 @@ +version: 2 +updates: + # Track in case we ever add dependencies. + - package-ecosystem: 'gomod' + directory: '/' + schedule: + interval: 'weekly' + commit-message: + prefix: 'chore' + + # Keep example and test/benchmark deps up-to-date. + - package-ecosystem: 'gomod' + directories: + - '/internal/examples' + - '/internal/thirdparty' + schedule: + interval: 'monthly' + commit-message: + prefix: 'chore' + labels: [] + groups: + internal-deps: + patterns: + - '*' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c650580..836381ef 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,11 @@ name: ci -on: [push, pull_request] +on: + push: + branches: + - master + pull_request: + branches: + - master concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} cancel-in-progress: true @@ -9,30 +15,36 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: ./ci/fmt.sh + - run: make fmt lint: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - run: go version - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: ./ci/lint.sh + - run: make lint test: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: ./ci/test.sh - - uses: actions/upload-artifact@v3 + - run: make test + - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html @@ -41,7 +53,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: ./ci/bench.sh + - run: make bench diff --git a/.github/workflows/daily.yml b/.github/workflows/daily.yml index b1e64fbc..62e3d337 100644 --- a/.github/workflows/daily.yml +++ b/.github/workflows/daily.yml @@ -12,19 +12,25 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: AUTOBAHN=1 ./ci/bench.sh + - run: AUTOBAHN=1 make bench test: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: AUTOBAHN=1 ./ci/test.sh - - uses: actions/upload-artifact@v3 + - run: AUTOBAHN=1 make test + - uses: actions/upload-artifact@v4 with: name: coverage.html path: ./ci/out/coverage.html @@ -34,21 +40,27 @@ jobs: - uses: actions/checkout@v4 with: ref: dev - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: AUTOBAHN=1 ./ci/bench.sh + - run: AUTOBAHN=1 make bench test-dev: runs-on: ubuntu-latest steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns - uses: actions/checkout@v4 with: ref: dev - - uses: actions/setup-go@v4 + - uses: actions/setup-go@v5 with: go-version-file: ./go.mod - - run: AUTOBAHN=1 ./ci/test.sh - - uses: actions/upload-artifact@v3 + - run: AUTOBAHN=1 make test + - uses: actions/upload-artifact@v4 with: - name: coverage.html + name: coverage-dev.html path: ./ci/out/coverage.html diff --git a/.github/workflows/static.yml b/.github/workflows/static.yml new file mode 100644 index 00000000..a78ce1b9 --- /dev/null +++ b/.github/workflows/static.yml @@ -0,0 +1,52 @@ +name: static + +on: + push: + branches: ['master'] + workflow_dispatch: + +# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages. +permissions: + contents: read + pages: write + id-token: write + +concurrency: + group: pages + cancel-in-progress: true + +jobs: + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + steps: + - name: Disable AppArmor + if: runner.os == 'Linux' + run: | + # Disable AppArmor for Ubuntu 23.10+. + # https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md + echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod + - name: Generate coverage and badge + run: | + make test + mkdir -p ./ci/out/static + cp ./ci/out/coverage.html ./ci/out/static/coverage.html + percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%') + wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success" + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: ./ci/out/static/ + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/LICENSE.txt b/LICENSE.txt index 77b5bef6..7e79329f 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2023 Anmol Sethi +Copyright (c) 2025 Coder Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..a3e4a20d --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +.PHONY: all +all: fmt lint test + +.PHONY: fmt +fmt: + ./ci/fmt.sh + +.PHONY: lint +lint: + ./ci/lint.sh + +.PHONY: test +test: + ./ci/test.sh + +.PHONY: bench +bench: + ./ci/bench.sh \ No newline at end of file diff --git a/README.md b/README.md index 1c5751d8..80d2b3cc 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,36 @@ # websocket -[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) -[![coverage](https://img.shields.io/badge/coverage-91%25-success)](https://nhooyr.io/websocket/coverage.html) +[![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket) +[![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html) websocket is a minimal and idiomatic WebSocket library for Go. ## Install ```sh -go get nhooyr.io/websocket +go get github.com/coder/websocket ``` +> [!NOTE] +> Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket). +> We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from +> 2019 to 2024. + ## Highlights - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) -- JSON helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage +- [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports) +- JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage - Zero alloc reads and writes - Concurrent writes -- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) -- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper -- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API +- [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close) +- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper +- [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression -- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections -- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) +- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections +- Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm) ## Roadmap @@ -58,7 +63,9 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { } defer c.CloseNow() - ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + // Set the context as needed. Use of r.Context() is not recommended + // to avoid surprising behavior (see http.Hijacker). + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() var v interface{} @@ -102,12 +109,14 @@ Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): - Mature and widely used - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) +- No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection. + - Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411) -Advantages of nhooyr.io/websocket: +Advantages of github.com/coder/websocket: - Minimal and idiomatic API - - Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. -- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper + - Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. +- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) @@ -115,24 +124,24 @@ Advantages of nhooyr.io/websocket: - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) -- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API +- Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API - Gorilla requires registering a pong callback before sending a Ping - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) -- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage +- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326) - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode -- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) +- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) #### golang.org/x/net/websocket [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). -The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning -to nhooyr.io/websocket. +The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning +to github.com/coder/websocket. #### gobwas/ws @@ -141,7 +150,7 @@ in an event driven style for performance. See the author's [blog post](https://m However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws -When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. +When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use. #### lesismal/nbio @@ -150,4 +159,4 @@ event driven for performance reasons. However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio -When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. +When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use. diff --git a/accept.go b/accept.go index 285b3103..f45fdd0b 100644 --- a/accept.go +++ b/accept.go @@ -5,6 +5,7 @@ package websocket import ( "bytes" + "context" "crypto/sha1" "encoding/base64" "errors" @@ -14,10 +15,10 @@ import ( "net/http" "net/textproto" "net/url" - "path/filepath" + "path" "strings" - "nhooyr.io/websocket/internal/errd" + "github.com/coder/websocket/internal/errd" ) // AcceptOptions represents Accept's options. @@ -41,8 +42,8 @@ type AcceptOptions struct { // One would set this field to []string{"example.com"} to authorize example.com to connect. // // Each pattern is matched case insensitively against the request origin host - // with filepath.Match. - // See https://golang.org/pkg/path/filepath/#Match + // with path.Match. + // See https://golang.org/pkg/path/#Match // // Please ensure you understand the ramifications of enabling this. // If used incorrectly your WebSocket server will be open to CSRF attacks. @@ -62,6 +63,22 @@ type AcceptOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { @@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // // Accept will write a response to w on all errors. +// +// Note that using the http.Request Context after Accept returns may lead to +// unexpected behavior (see http.Hijacker). func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { return accept(w, r, opts) } @@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con if !opts.InsecureSkipVerify { err = authenticateOrigin(r, opts.OriginPatterns) if err != nil { - if errors.Is(err, filepath.ErrBadPattern) { + if errors.Is(err, path.ErrBadPattern) { log.Printf("websocket: %v", err) err = errors.New(http.StatusText(http.StatusForbidden)) } @@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } } - hj, ok := w.(http.Hijacker) + hj, ok := hijacker(w) if !ok { err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) @@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con client: false, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: brw.Reader, bw: brw.Writer, @@ -221,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { for _, hostPattern := range originHosts { matched, err := match(hostPattern, u.Host) if err != nil { - return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) + return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err) } if matched { return nil @@ -234,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { } func match(pattern, s string) (bool, error) { - return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) + return path.Match(strings.ToLower(pattern), strings.ToLower(s)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { diff --git a/accept_test.go b/accept_test.go index 7cb85d0f..3b45ac5c 100644 --- a/accept_test.go +++ b/accept_test.go @@ -10,10 +10,11 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "testing" - "nhooyr.io/websocket/internal/test/assert" - "nhooyr.io/websocket/internal/test/xrand" + "github.com/coder/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/xrand" ) func TestAccept(t *testing.T) { @@ -142,6 +143,69 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.Contains(t, err, `failed to hijack connection`) }) + + t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + _, err := Accept(w, r, nil) + assert.Contains(t, err, "failed to hijack connection") + }) + + t.Run("closeRace", func(t *testing.T) { + t.Parallel() + + server, _ := net.Pipe() + + rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server)) + newResponseWriter := func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (net.Conn, *bufio.ReadWriter, error) { + return server, rw, nil + }, + } + } + w := newResponseWriter() + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16)) + + c, err := Accept(w, r, nil) + wg := &sync.WaitGroup{} + wg.Add(2) + go func() { + c.Close(StatusInternalError, "the sky is falling") + wg.Done() + }() + go func() { + c.CloseNow() + wg.Done() + }() + wg.Wait() + assert.Success(t, err) + }) } func Test_verifyClientHandshake(t *testing.T) { @@ -497,3 +561,14 @@ var _ http.Hijacker = mockHijacker{} func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return mj.hijack() } + +type mockUnwrapper struct { + http.ResponseWriter + unwrap func() http.ResponseWriter +} + +var _ rwUnwrapper = mockUnwrapper{} + +func (mu mockUnwrapper) Unwrap() http.ResponseWriter { + return mu.unwrap() +} diff --git a/autobahn_test.go b/autobahn_test.go index 57ceebd5..cd0cc9bb 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -17,11 +17,11 @@ import ( "testing" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/test/assert" - "nhooyr.io/websocket/internal/test/wstest" - "nhooyr.io/websocket/internal/util" + "github.com/coder/websocket" + "github.com/coder/websocket/internal/errd" + "github.com/coder/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/wstest" + "github.com/coder/websocket/internal/util" ) var excludedAutobahnCases = []string{ @@ -92,7 +92,7 @@ func TestAutobahn(t *testing.T) { } }) - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) + c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil) assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") diff --git a/ci/bench.sh b/ci/bench.sh index a553b93a..30c06986 100755 --- a/ci/bench.sh +++ b/ci/bench.sh @@ -2,8 +2,19 @@ set -eu cd -- "$(dirname "$0")/.." -go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" . +go test --run=^$ --bench=. --benchmem "$@" ./... +# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test ( cd ./internal/thirdparty - go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" . + go test --run=^$ --bench=. --benchmem "$@" . + + GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" . + if [ "$#" -eq 0 ]; then + if [ "${CI-}" ]; then + sudo apt-get update + sudo apt-get install -y qemu-user-static + ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 + fi + qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem + fi ) diff --git a/ci/fmt.sh b/ci/fmt.sh index 6e5a68e4..588510ba 100755 --- a/ci/fmt.sh +++ b/ci/fmt.sh @@ -2,19 +2,24 @@ set -eu cd -- "$(dirname "$0")/.." +X_TOOLS_VERSION=v0.31.0 + go mod tidy (cd ./internal/thirdparty && go mod tidy) (cd ./internal/examples && go mod tidy) gofmt -w -s . -go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . +go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" . -npx prettier@3.0.3 \ - --write \ +git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \ + --check \ --log-level=warn \ --print-width=90 \ --no-semi \ --single-quote \ - --arrow-parens=avoid \ - $(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html") + --arrow-parens=avoid + +go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go -go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go +if [ "${CI-}" ]; then + git diff --exit-code +fi diff --git a/ci/lint.sh b/ci/lint.sh index 3cf8eee4..316b035d 100755 --- a/ci/lint.sh +++ b/ci/lint.sh @@ -2,10 +2,13 @@ set -eu cd -- "$(dirname "$0")/.." +STATICCHECK_VERSION=v0.6.1 +GOVULNCHECK_VERSION=v1.1.4 + go vet ./... GOOS=js GOARCH=wasm go vet ./... -go install honnef.co/go/tools/cmd/staticcheck@latest +go install honnef.co/go/tools/cmd/staticcheck@${STATICCHECK_VERSION} staticcheck ./... GOOS=js GOARCH=wasm staticcheck ./... @@ -15,7 +18,7 @@ govulncheck() { cat "$tmpf" fi } -go install golang.org/x/vuln/cmd/govulncheck@latest +go install golang.org/x/vuln/cmd/govulncheck@${GOVULNCHECK_VERSION} govulncheck ./... GOOS=js GOARCH=wasm govulncheck ./... diff --git a/ci/test.sh b/ci/test.sh index 83bb9832..cc3c22d7 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -11,7 +11,20 @@ cd -- "$(dirname "$0")/.." go test "$@" ./... ) -go install github.com/agnivade/wasmbrowsertest@latest +( + GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" . + if [ "$#" -eq 0 ]; then + if [ "${CI-}" ]; then + sudo apt-get update + sudo apt-get install -y qemu-user-static + ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64 + fi + qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask + fi +) + + +go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... sed -i.bak '/stringer\.go/d' ci/out/coverage.prof sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof diff --git a/close.go b/close.go index c3dee7e0..f94951dc 100644 --- a/close.go +++ b/close.go @@ -11,7 +11,7 @@ import ( "net" "time" - "nhooyr.io/websocket/internal/errd" + "github.com/coder/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. @@ -93,85 +93,110 @@ func CloseStatus(err error) StatusCode { // The connection can only be closed once. Additional calls to Close // are no-ops. // -// The maximum length of reason must be 125 bytes. Avoid -// sending a dynamic reason. +// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once // complete. -func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() - return c.closeHandshake(code, reason) +func (c *Conn) Close(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + if c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } + return net.ErrClosed + } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() + + err = c.closeHandshake(code, reason) + + err2 := c.close() + if err == nil && err2 != nil { + err = err2 + } + + err2 = c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 + } + + return err } // CloseNow closes the WebSocket connection without attempting a close handshake. // Use when you do not want the overhead of the close handshake. func (c *Conn) CloseNow() (err error) { - defer c.wg.Wait() - defer errd.Wrap(&err, "failed to close WebSocket") + defer errd.Wrap(&err, "failed to immediately close WebSocket") - if c.isClosed() { + if c.casClosing() { + err = c.waitGoroutines() + if err != nil { + return err + } return net.ErrClosed } + defer func() { + if errors.Is(err, net.ErrClosed) { + err = nil + } + }() - c.close(nil) - return c.closeErr -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() + err = c.close() - if writeErr != nil { - return writeErr + err2 := c.waitGoroutines() + if err == nil && err2 != nil { + err = err2 } + return err +} - if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { - return closeHandshakeErr +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.writeClose(code, reason) + if err != nil { + return err } + err = c.waitCloseHandshake() + if CloseStatus(err) != code { + return err + } return nil } func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return net.ErrClosed - } - ce := CloseError{ Code: code, Reason: reason, } var p []byte - var marshalErr error + var err error if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil + p, err = ce.bytes() + if err != nil { + return err + } } - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - if marshalErr != nil { - return marshalErr + err = c.writeControl(ctx, opClose, p) + // If the connection closed as we're writing we ignore the error as we might + // have written the close frame, the peer responded and then someone else read it + // and closed the connection. + if err != nil && !errors.Is(err, net.ErrClosed) { + return err } - return writeErr + return nil } func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -181,10 +206,6 @@ func (c *Conn) waitCloseHandshake() error { } defer c.readMu.unlock() - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - for i := int64(0); i < c.msgReader.payloadLength; i++ { _, err := c.br.ReadByte() if err != nil { @@ -207,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error { } } +func (c *Conn) waitGoroutines() error { + t := time.NewTimer(time.Second * 15) + defer t.Stop() + + select { + case <-c.timeoutLoopDone: + case <-t.C: + return errors.New("failed to wait for timeoutLoop goroutine to exit") + } + + c.closeReadMu.Lock() + closeRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closeRead { + select { + case <-c.closeReadDone: + case <-t.C: + return errors.New("failed to wait for close read goroutine to exit") + } + } + + select { + case <-c.closed: + case <-t.C: + return errors.New("failed to wait for connection to be closed") + } + + return nil +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ @@ -277,16 +328,8 @@ func (ce CloseError) bytesErr() ([]byte, error) { return buf, nil } -func (c *Conn) setCloseErr(err error) { - c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil && err != nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) - } +func (c *Conn) casClosing() bool { + return c.closing.Swap(true) } func (c *Conn) isClosed() bool { diff --git a/close_test.go b/close_test.go index 6bf3c256..aec582c1 100644 --- a/close_test.go +++ b/close_test.go @@ -9,7 +9,7 @@ import ( "strings" "testing" - "nhooyr.io/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/assert" ) func TestCloseError(t *testing.T) { diff --git a/compress_test.go b/compress_test.go index 667e1408..d97492cf 100644 --- a/compress_test.go +++ b/compress_test.go @@ -10,8 +10,8 @@ import ( "strings" "testing" - "nhooyr.io/websocket/internal/test/assert" - "nhooyr.io/websocket/internal/test/xrand" + "github.com/coder/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/xrand" ) func Test_slidingWindow(t *testing.T) { diff --git a/conn.go b/conn.go index e133cd67..42fe89fe 100644 --- a/conn.go +++ b/conn.go @@ -6,7 +6,6 @@ package websocket import ( "bufio" "context" - "errors" "fmt" "io" "net" @@ -43,7 +42,7 @@ const ( // This applies to context expirations as well unfortunately. // See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220 type Conn struct { - noCopy + noCopy noCopy subprotocol string rwc io.ReadWriteCloser @@ -53,15 +52,15 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context + readTimeout chan context.Context + writeTimeout chan context.Context + timeoutLoopDone chan struct{} // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader // Write state. msgWriter *msgWriter @@ -70,15 +69,25 @@ type Conn struct { writeHeaderBuf [8]byte writeHeader header - wg sync.WaitGroup - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool - - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} + // Close handshake state. + closeStateMu sync.RWMutex + closeReceivedErr error + closeSentErr error + + // CloseRead state. + closeReadMu sync.Mutex + closeReadCtx context.Context + closeReadDone chan struct{} + + closing atomic.Bool + closeMu sync.Mutex // Protects following. + closed chan struct{} + + pingCounter atomic.Int64 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) } type connConfig struct { @@ -87,6 +96,8 @@ type connConfig struct { client bool copts *compressionOptions flateThreshold int + onPingReceived func(context.Context, []byte) bool + onPongReceived func(context.Context, []byte) br *bufio.Reader bw *bufio.Writer @@ -103,11 +114,14 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + onPingReceived: cfg.onPingReceived, + onPongReceived: cfg.onPongReceived, } c.readMu = newMu(c) @@ -128,14 +142,10 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) + c.close() }) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.timeoutLoop() - }() + go c.timeoutLoop() return c } @@ -146,35 +156,29 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) close(err error) { +func (c *Conn) close() error { c.closeMu.Lock() defer c.closeMu.Unlock() if c.isClosed() { - return - } - if err == nil { - err = c.rwc.Close() + return net.ErrClosed } - c.setCloseErrLocked(err) - - close(c.closed) runtime.SetFinalizer(c, nil) + close(c.closed) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns // closeErr. - c.rwc.Close() - - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.msgWriter.close() - c.msgReader.close() - }() + err := c.rwc.Close() + // With the close of rwc, these become safe to close. + c.msgWriter.close() + c.msgReader.close() + return err } func (c *Conn) timeoutLoop() { + defer close(c.timeoutLoopDone) + readCtx := context.Background() writeCtx := context.Background() @@ -187,14 +191,10 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.readTimeout: case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - c.wg.Add(1) - go func() { - defer c.wg.Done() - c.writeError(StatusPolicyViolation, errors.New("read timed out")) - }() + c.close() + return case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.close() return } } @@ -212,9 +212,9 @@ func (c *Conn) flate() bool { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) + p := c.pingCounter.Add(1) - err := c.ping(ctx, strconv.Itoa(int(p))) + err := c.ping(ctx, strconv.FormatInt(p, 10)) if err != nil { return fmt.Errorf("failed to ping: %w", err) } @@ -243,9 +243,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { case <-c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err + return fmt.Errorf("failed to wait for pong: %w", ctx.Err()) case <-pong: return nil } @@ -281,9 +279,7 @@ func (m *mu) lock(ctx context.Context) error { case <-m.c.closed: return net.ErrClosed case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err + return fmt.Errorf("failed to acquire lock: %w", ctx.Err()) case m.ch <- struct{}{}: // To make sure the connection is certainly alive. // As it's possible the send on m.ch was selected diff --git a/conn_test.go b/conn_test.go index 97b172dc..45bb75be 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "net/http/httptest" "os" @@ -16,13 +17,13 @@ import ( "testing" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/test/assert" - "nhooyr.io/websocket/internal/test/wstest" - "nhooyr.io/websocket/internal/test/xrand" - "nhooyr.io/websocket/internal/xsync" - "nhooyr.io/websocket/wsjson" + "github.com/coder/websocket" + "github.com/coder/websocket/internal/errd" + "github.com/coder/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/wstest" + "github.com/coder/websocket/internal/test/xrand" + "github.com/coder/websocket/internal/xsync" + "github.com/coder/websocket/wsjson" ) func TestConn(t *testing.T) { @@ -96,6 +97,85 @@ func TestConn(t *testing.T) { assert.Contains(t, err, "failed to wait for pong") }) + t.Run("pingReceivedPongReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return true + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Success(t, err) + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2) + assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1)) + }) + + t.Run("pingReceivedPongNotReceived", func(t *testing.T) { + var pingReceived1, pongReceived1 bool + var pingReceived2, pongReceived2 bool + tt, c1, c2 := newConnTest(t, + &websocket.DialOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived1 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived1 = true + }, + }, &websocket.AcceptOptions{ + OnPingReceived: func(ctx context.Context, payload []byte) bool { + pingReceived2 = true + return false + }, + OnPongReceived: func(ctx context.Context, payload []byte) { + pongReceived2 = true + }, + }, + ) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) + assert.Contains(t, err, "failed to wait for pong") + + c1.CloseNow() + c2.CloseNow() + + assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2) + assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1)) + }) + t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) @@ -345,6 +425,9 @@ func TestConn(t *testing.T) { func TestWasm(t *testing.T) { t.Parallel() + if os.Getenv("CI") == "" { + t.SkipNow() + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { err := echoServer(w, r, &websocket.AcceptOptions{ @@ -360,8 +443,8 @@ func TestWasm(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v") + cmd.Env = append(cleanEnv(os.Environ()), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL)) b, err := cmd.CombinedOutput() if err != nil { @@ -369,6 +452,18 @@ func TestWasm(t *testing.T) { } } +func cleanEnv(env []string) (out []string) { + for _, e := range env { + // Filter out GITHUB envs and anything with token in it, + // especially GITHUB_TOKEN in CI as it breaks TestWasm. + if strings.HasPrefix(e, "GITHUB") || strings.Contains(e, "TOKEN") { + continue + } + out = append(out, e) + } + return out +} + func assertCloseStatus(exp websocket.StatusCode, err error) error { if websocket.CloseStatus(err) == -1 { return fmt.Errorf("expected websocket.CloseError: %T %v", err, err) @@ -445,7 +540,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } func BenchmarkConn(b *testing.B) { - var benchCases = []struct { + benchCases := []struct { name string mode websocket.CompressionMode }{ @@ -610,3 +705,149 @@ func TestConcurrentClosePing(t *testing.T) { }() } } + +func TestConnClosePropagation(t *testing.T) { + t.Parallel() + + want := []byte("hello") + keepWriting := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + err := c.Write(context.Background(), websocket.MessageText, want) + if err != nil { + return err + } + } + }) + } + keepReading := func(c *websocket.Conn) <-chan error { + return xsync.Go(func() error { + for { + _, got, err := c.Read(context.Background()) + if err != nil { + return err + } + if !bytes.Equal(want, got) { + return fmt.Errorf("unexpected message: want %q, got %q", want, got) + } + } + }) + } + checkReadErr := func(t *testing.T, err error) { + // Check read error (output depends on when read is called in relation to connection closure). + var ce websocket.CloseError + if errors.As(err, &ce) { + assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code) + } else { + assert.ErrorIs(t, net.ErrClosed, err) + } + } + checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) { + for _, c := range conn { + // Check write error. + err := c.Write(context.Background(), websocket.MessageText, want) + assert.ErrorIs(t, net.ErrClosed, err) + + _, _, err = c.Read(context.Background()) + checkReadErr(t, err) + } + } + + t.Run("CloseOtherSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + + _, got, err := other.Read(tt.ctx) + assert.Success(t, err) + assert.Equal(t, "msg", want, got) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringWrite", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = this.CloseRead(tt.ctx) + thisWriteErr := keepWriting(this) + otherReadErr := keepReading(other) + + err := this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisWriteErr: + assert.ErrorIs(t, net.ErrClosed, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseOtherSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + _ = other.CloseRead(tt.ctx) + errs := keepReading(this) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = other.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-errs: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) + t.Run("CloseThisSideDuringRead", func(t *testing.T) { + tt, this, other := newConnTest(t, nil, nil) + + thisReadErr := keepReading(this) + otherReadErr := keepReading(other) + + err := other.Write(tt.ctx, websocket.MessageText, want) + assert.Success(t, err) + + err = this.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + + select { + case err := <-thisReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + select { + case err := <-otherReadErr: + checkReadErr(t, err) + case <-tt.ctx.Done(): + t.Fatal(tt.ctx.Err()) + } + + checkConnErrs(t, this, other) + }) +} diff --git a/dial.go b/dial.go index e4c4daa1..0b11ecbb 100644 --- a/dial.go +++ b/dial.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "nhooyr.io/websocket/internal/errd" + "github.com/coder/websocket/internal/errd" ) // DialOptions represents Dial's options. @@ -48,6 +48,22 @@ type DialOptions struct { // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // for CompressionContextTakeover. CompressionThreshold int + + // OnPingReceived is an optional callback invoked synchronously when a ping frame is received. + // + // The payload contains the application data of the ping frame. + // If the callback returns false, the subsequent pong frame will not be sent. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + OnPingReceived func(ctx context.Context, payload []byte) bool + + // OnPongReceived is an optional callback invoked synchronously when a pong frame is received. + // + // The payload contains the application data of the pong frame. + // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine. + // + // Unlike OnPingReceived, this callback does not return a value because a pong frame + // is a response to a ping and does not trigger any further frame transmission. + OnPongReceived func(ctx context.Context, payload []byte) } func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) { @@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( client: true, copts: copts, flateThreshold: opts.CompressionThreshold, + onPingReceived: opts.OnPingReceived, + onPongReceived: opts.OnPongReceived, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil diff --git a/dial_test.go b/dial_test.go index 237a2874..f94cd73b 100644 --- a/dial_test.go +++ b/dial_test.go @@ -15,10 +15,10 @@ import ( "testing" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/assert" - "nhooyr.io/websocket/internal/util" - "nhooyr.io/websocket/internal/xsync" + "github.com/coder/websocket" + "github.com/coder/websocket/internal/test/assert" + "github.com/coder/websocket/internal/util" + "github.com/coder/websocket/internal/xsync" ) func TestBadDials(t *testing.T) { diff --git a/doc.go b/doc.go index 2ab648a6..03edf129 100644 --- a/doc.go +++ b/doc.go @@ -15,7 +15,7 @@ // // The wsjson subpackage contain helpers for JSON and protobuf messages. // -// More documentation at https://nhooyr.io/websocket. +// More documentation at https://github.com/coder/websocket. // // # Wasm // @@ -31,4 +31,4 @@ // - Conn.CloseNow is Close(StatusGoingAway, "") // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op // - *http.Response from Dial is &http.Response{} with a 101 status code on success -package websocket // import "nhooyr.io/websocket" +package websocket // import "github.com/coder/websocket" diff --git a/example_test.go b/example_test.go index 590c0411..4cc0cf11 100644 --- a/example_test.go +++ b/example_test.go @@ -6,8 +6,8 @@ import ( "net/http" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" ) func ExampleAccept() { diff --git a/export_test.go b/export_test.go index a644d8f0..d3443991 100644 --- a/export_test.go +++ b/export_test.go @@ -6,7 +6,7 @@ package websocket import ( "net" - "nhooyr.io/websocket/internal/util" + "github.com/coder/websocket/internal/util" ) func (c *Conn) RecordBytesWritten() *int { diff --git a/frame.go b/frame.go index 351632fd..e7ab76be 100644 --- a/frame.go +++ b/frame.go @@ -8,9 +8,8 @@ import ( "fmt" "io" "math" - "math/bits" - "nhooyr.io/websocket/internal/errd" + "github.com/coder/websocket/internal/errd" ) // opcode represents a WebSocket opcode. @@ -172,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { return nil } - -// mask applies the WebSocket masking algorithm to p -// with the given key. -// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the correctly rotated key to -// to continue to mask/unmask the message. -// -// It is optimized for LittleEndian and expects the key -// to be in little endian. -// -// See https://github.com/golang/go/issues/31586 -func mask(key uint32, b []byte) uint32 { - if len(b) >= 8 { - key64 := uint64(key)<<32 | uint64(key) - - // At some point in the future we can clean these unrolled loops up. - // See https://github.com/golang/go/issues/31586#issuecomment-487436401 - - // Then we xor until b is less than 128 bytes. - for len(b) >= 128 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - v = binary.LittleEndian.Uint64(b[64:72]) - binary.LittleEndian.PutUint64(b[64:72], v^key64) - v = binary.LittleEndian.Uint64(b[72:80]) - binary.LittleEndian.PutUint64(b[72:80], v^key64) - v = binary.LittleEndian.Uint64(b[80:88]) - binary.LittleEndian.PutUint64(b[80:88], v^key64) - v = binary.LittleEndian.Uint64(b[88:96]) - binary.LittleEndian.PutUint64(b[88:96], v^key64) - v = binary.LittleEndian.Uint64(b[96:104]) - binary.LittleEndian.PutUint64(b[96:104], v^key64) - v = binary.LittleEndian.Uint64(b[104:112]) - binary.LittleEndian.PutUint64(b[104:112], v^key64) - v = binary.LittleEndian.Uint64(b[112:120]) - binary.LittleEndian.PutUint64(b[112:120], v^key64) - v = binary.LittleEndian.Uint64(b[120:128]) - binary.LittleEndian.PutUint64(b[120:128], v^key64) - b = b[128:] - } - - // Then we xor until b is less than 64 bytes. - for len(b) >= 64 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - b = b[64:] - } - - // Then we xor until b is less than 32 bytes. - for len(b) >= 32 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - b = b[32:] - } - - // Then we xor until b is less than 16 bytes. - for len(b) >= 16 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - b = b[16:] - } - - // Then we xor until b is less than 8 bytes. - for len(b) >= 8 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - b = b[8:] - } - } - - // Then we xor until b is less than 4 bytes. - for len(b) >= 4 { - v := binary.LittleEndian.Uint32(b) - binary.LittleEndian.PutUint32(b, v^key) - b = b[4:] - } - - // xor remaining bytes. - for i := range b { - b[i] ^= byte(key) - key = bits.RotateLeft32(key, -8) - } - - return key -} diff --git a/frame_test.go b/frame_test.go index e697e198..08874cb5 100644 --- a/frame_test.go +++ b/frame_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "nhooyr.io/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/assert" ) func TestHeader(t *testing.T) { @@ -97,7 +97,7 @@ func Test_mask(t *testing.T) { key := []byte{0xa, 0xb, 0xc, 0xff} key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := mask(key32, p) + gotKey32 := mask(p, key32) expP := []byte{0, 0, 0, 0x0d, 0x6} assert.Equal(t, "p", expP, p) diff --git a/go.mod b/go.mod index 715a9f7a..d32fbd77 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module nhooyr.io/websocket +module github.com/coder/websocket -go 1.19 +go 1.23 diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..e69de29b diff --git a/hijack.go b/hijack.go new file mode 100644 index 00000000..9cce45ca --- /dev/null +++ b/hijack.go @@ -0,0 +1,33 @@ +//go:build !js + +package websocket + +import ( + "net/http" +) + +type rwUnwrapper interface { + Unwrap() http.ResponseWriter +} + +// hijacker returns the Hijacker interface of the http.ResponseWriter. +// It follows the Unwrap method of the http.ResponseWriter if available, +// matching the behavior of http.ResponseController. If the Hijacker +// interface is not found, it returns false. +// +// Since the http.ResponseController is not available in Go 1.19, and +// does not support checking the presence of the Hijacker interface, +// this function is used to provide a consistent way to check for the +// Hijacker interface across Go versions. +func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) { + for { + switch t := rw.(type) { + case http.Hijacker: + return t, true + case rwUnwrapper: + rw = t.Unwrap() + default: + return nil, false + } + } +} diff --git a/hijack_go120_test.go b/hijack_go120_test.go new file mode 100644 index 00000000..0f0673a9 --- /dev/null +++ b/hijack_go120_test.go @@ -0,0 +1,38 @@ +//go:build !js && go1.20 + +package websocket + +import ( + "bufio" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/coder/websocket/internal/test/assert" +) + +func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) { + t.Parallel() + + rr := httptest.NewRecorder() + w := mockUnwrapper{ + ResponseWriter: rr, + unwrap: func() http.ResponseWriter { + return mockHijacker{ + ResponseWriter: rr, + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + }, + } + + _, _, err := http.NewResponseController(w).Hijack() + assert.Contains(t, err, "haha") + hj, ok := hijacker(w) + assert.Equal(t, "hijacker found", ok, true) + _, _, err = hj.Hijack() + assert.Contains(t, err, "haha") +} diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go index aa826fba..12cf577a 100644 --- a/internal/bpool/bpool.go +++ b/internal/bpool/bpool.go @@ -5,15 +5,16 @@ import ( "sync" ) -var bpool sync.Pool +var bpool = sync.Pool{ + New: func() any { + return &bytes.Buffer{} + }, +} // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { b := bpool.Get() - if b == nil { - return &bytes.Buffer{} - } return b.(*bytes.Buffer) } diff --git a/internal/examples/chat/README.md b/internal/examples/chat/README.md index ca1024a0..4d354586 100644 --- a/internal/examples/chat/README.md +++ b/internal/examples/chat/README.md @@ -1,11 +1,11 @@ # Chat Example -This directory contains a full stack example of a simple chat webapp using nhooyr.io/websocket. +This directory contains a full stack example of a simple chat webapp using github.com/coder/websocket. ```bash $ cd examples/chat $ go run . localhost:0 -listening on http://127.0.0.1:51055 +listening on ws://127.0.0.1:51055 ``` Visit the printed URL to submit and view broadcasted messages in a browser. diff --git a/internal/examples/chat/chat.go b/internal/examples/chat/chat.go index 8b1e30c1..29f304b7 100644 --- a/internal/examples/chat/chat.go +++ b/internal/examples/chat/chat.go @@ -12,7 +12,7 @@ import ( "golang.org/x/time/rate" - "nhooyr.io/websocket" + "github.com/coder/websocket" ) // chatServer enables broadcasting to a set of subscribers. @@ -70,7 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // subscribeHandler accepts the WebSocket connection and then subscribes // it to all future messages. func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) { - err := cs.subscribe(r.Context(), w, r) + err := cs.subscribe(w, r) if errors.Is(err, context.Canceled) { return } @@ -111,7 +111,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) { // // It uses CloseRead to keep reading from the connection to process control // messages and cancel the context if the connection drops. -func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *http.Request) error { +func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error { var mu sync.Mutex var c *websocket.Conn var closed bool @@ -142,7 +142,7 @@ func (cs *chatServer) subscribe(ctx context.Context, w http.ResponseWriter, r *h mu.Unlock() defer c.CloseNow() - ctx = c.CloseRead(ctx) + ctx := c.CloseRead(context.Background()) for { select { diff --git a/internal/examples/chat/chat_test.go b/internal/examples/chat/chat_test.go index f80f1de1..dcada0b2 100644 --- a/internal/examples/chat/chat_test.go +++ b/internal/examples/chat/chat_test.go @@ -14,7 +14,7 @@ import ( "golang.org/x/time/rate" - "nhooyr.io/websocket" + "github.com/coder/websocket" ) func Test_chatServer(t *testing.T) { @@ -52,7 +52,7 @@ func Test_chatServer(t *testing.T) { // 10 clients are started that send 128 different // messages of max 128 bytes concurrently. // - // The test verifies that every message is seen by ever client + // The test verifies that every message is seen by every client // and no errors occur anywhere. t.Run("concurrency", func(t *testing.T) { t.Parallel() diff --git a/internal/examples/chat/index.html b/internal/examples/chat/index.html index 64edd286..7038342d 100644 --- a/internal/examples/chat/index.html +++ b/internal/examples/chat/index.html @@ -2,7 +2,7 @@ - nhooyr.io/websocket - Chat Example + github.com/coder/websocket - Chat Example diff --git a/internal/examples/chat/main.go b/internal/examples/chat/main.go index 3fcec6be..e3432984 100644 --- a/internal/examples/chat/main.go +++ b/internal/examples/chat/main.go @@ -31,7 +31,7 @@ func run() error { if err != nil { return err } - log.Printf("listening on http://%v", l.Addr()) + log.Printf("listening on ws://%v", l.Addr()) cs := newChatServer() s := &http.Server{ diff --git a/internal/examples/echo/README.md b/internal/examples/echo/README.md index 7f42c3c5..3abbbb57 100644 --- a/internal/examples/echo/README.md +++ b/internal/examples/echo/README.md @@ -1,11 +1,11 @@ # Echo Example -This directory contains a echo server example using nhooyr.io/websocket. +This directory contains a echo server example using github.com/coder/websocket. ```bash $ cd examples/echo $ go run . localhost:0 -listening on http://127.0.0.1:51055 +listening on ws://127.0.0.1:51055 ``` You can use a WebSocket client like https://github.com/hashrocket/ws to connect. All messages diff --git a/internal/examples/echo/main.go b/internal/examples/echo/main.go index 16d78a79..47e30d05 100644 --- a/internal/examples/echo/main.go +++ b/internal/examples/echo/main.go @@ -31,7 +31,7 @@ func run() error { if err != nil { return err } - log.Printf("listening on http://%v", l.Addr()) + log.Printf("listening on ws://%v", l.Addr()) s := &http.Server{ Handler: echoServer{ diff --git a/internal/examples/echo/server.go b/internal/examples/echo/server.go index 246ad582..37e2f2c4 100644 --- a/internal/examples/echo/server.go +++ b/internal/examples/echo/server.go @@ -9,7 +9,7 @@ import ( "golang.org/x/time/rate" - "nhooyr.io/websocket" + "github.com/coder/websocket" ) // echoServer is the WebSocket echo server implementation. @@ -37,7 +37,7 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) for { - err = echo(r.Context(), c, l) + err = echo(c, l) if websocket.CloseStatus(err) == websocket.StatusNormalClosure { return } @@ -51,8 +51,8 @@ func (s echoServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { // echo reads from the WebSocket connection and then writes // the received message back to it. // The entire function has 10s to complete. -func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { - ctx, cancel := context.WithTimeout(ctx, time.Second*10) +func echo(c *websocket.Conn, l *rate.Limiter) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() err := l.Wait(ctx) diff --git a/internal/examples/echo/server_test.go b/internal/examples/echo/server_test.go index 9b608301..81e8cfc2 100644 --- a/internal/examples/echo/server_test.go +++ b/internal/examples/echo/server_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" ) // Test_echoServer tests the echoServer by sending it 5 different messages diff --git a/internal/examples/go.mod b/internal/examples/go.mod index c98b81ce..e368b76b 100644 --- a/internal/examples/go.mod +++ b/internal/examples/go.mod @@ -1,10 +1,10 @@ -module nhooyr.io/websocket/examples +module github.com/coder/websocket/examples -go 1.19 +go 1.23 -replace nhooyr.io/websocket => ../.. +replace github.com/coder/websocket => ../.. require ( - golang.org/x/time v0.3.0 - nhooyr.io/websocket v0.0.0-00010101000000-000000000000 + github.com/coder/websocket v0.0.0-00010101000000-000000000000 + golang.org/x/time v0.7.0 ) diff --git a/internal/examples/go.sum b/internal/examples/go.sum index f8a07e82..60aa8f9a 100644 --- a/internal/examples/go.sum +++ b/internal/examples/go.sum @@ -1,2 +1,2 @@ -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= +golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go index dc21a8f0..c0c8dcd7 100644 --- a/internal/test/wstest/echo.go +++ b/internal/test/wstest/echo.go @@ -7,9 +7,9 @@ import ( "io" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/xrand" - "nhooyr.io/websocket/internal/xsync" + "github.com/coder/websocket" + "github.com/coder/websocket/internal/test/xrand" + "github.com/coder/websocket/internal/xsync" ) // EchoLoop echos every msg received from c until an error diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index 8e1deb47..b8cf094d 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -10,7 +10,7 @@ import ( "net/http" "net/http/httptest" - "nhooyr.io/websocket" + "github.com/coder/websocket" ) // Pipe is used to create an in memory connection diff --git a/internal/thirdparty/frame_test.go b/internal/thirdparty/frame_test.go index 1a0ed125..75b05291 100644 --- a/internal/thirdparty/frame_test.go +++ b/internal/thirdparty/frame_test.go @@ -2,17 +2,19 @@ package thirdparty import ( "encoding/binary" + "runtime" "strconv" "testing" _ "unsafe" "github.com/gobwas/ws" _ "github.com/gorilla/websocket" + _ "github.com/lesismal/nbio/nbhttp/websocket" - _ "nhooyr.io/websocket" + _ "github.com/coder/websocket" ) -func basicMask(maskKey [4]byte, pos int, b []byte) int { +func basicMask(b []byte, maskKey [4]byte, pos int) int { for i := range b { b[i] ^= maskKey[pos&3] pos++ @@ -20,23 +22,34 @@ func basicMask(maskKey [4]byte, pos int, b []byte) int { return pos & 3 } +//go:linkname maskGo github.com/coder/websocket.maskGo +func maskGo(b []byte, key32 uint32) int + +//go:linkname maskAsm github.com/coder/websocket.maskAsm +func maskAsm(b *byte, len int, key32 uint32) uint32 + +//go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR +func nbioMaskBytes(b, key []byte) int + //go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes func gorillaMaskBytes(key [4]byte, pos int, b []byte) int -//go:linkname mask nhooyr.io/websocket.mask -func mask(key32 uint32, b []byte) int - func Benchmark_mask(b *testing.B) { + b.Run(runtime.GOARCH, benchmark_mask) +} + +func benchmark_mask(b *testing.B) { sizes := []int{ - 2, - 3, - 4, 8, 16, 32, 128, + 256, 512, + 1024, + 2048, 4096, + 8192, 16384, } @@ -48,22 +61,34 @@ func Benchmark_mask(b *testing.B) { name: "basic", fn: func(b *testing.B, key [4]byte, p []byte) { for i := 0; i < b.N; i++ { - basicMask(key, 0, p) + basicMask(p, key, 0) } }, }, { - name: "nhooyr", + name: "nhooyr-go", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + maskGo(p, key32) + } + }, + }, + { + name: "wdvxdr1123-asm", fn: func(b *testing.B, key [4]byte, p []byte) { key32 := binary.LittleEndian.Uint32(key[:]) b.ResetTimer() for i := 0; i < b.N; i++ { - mask(key32, p) + maskAsm(&p[0], len(p), key32) } }, }, + { name: "gorilla", fn: func(b *testing.B, key [4]byte, p []byte) { @@ -80,16 +105,25 @@ func Benchmark_mask(b *testing.B) { } }, }, + { + name: "nbio", + fn: func(b *testing.B, key [4]byte, p []byte) { + keyb := key[:] + for i := 0; i < b.N; i++ { + nbioMaskBytes(p, keyb) + } + }, + }, } key := [4]byte{1, 2, 3, 4} - for _, size := range sizes { - p := make([]byte, size) + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + for _, size := range sizes { + p := make([]byte, size) - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { + b.Run(strconv.Itoa(size), func(b *testing.B) { b.SetBytes(int64(size)) fn.fn(b, key, p) diff --git a/internal/thirdparty/gin_test.go b/internal/thirdparty/gin_test.go index 6d59578d..bd30ebdd 100644 --- a/internal/thirdparty/gin_test.go +++ b/internal/thirdparty/gin_test.go @@ -10,11 +10,11 @@ import ( "github.com/gin-gonic/gin" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/test/assert" - "nhooyr.io/websocket/internal/test/wstest" - "nhooyr.io/websocket/wsjson" + "github.com/coder/websocket" + "github.com/coder/websocket/internal/errd" + "github.com/coder/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/wstest" + "github.com/coder/websocket/wsjson" ) func TestGin(t *testing.T) { diff --git a/internal/thirdparty/go.mod b/internal/thirdparty/go.mod index 10eb45c1..7a86aca9 100644 --- a/internal/thirdparty/go.mod +++ b/internal/thirdparty/go.mod @@ -1,41 +1,45 @@ -module nhooyr.io/websocket/internal/thirdparty +module github.com/coder/websocket/internal/thirdparty -go 1.19 +go 1.23 -replace nhooyr.io/websocket => ../.. +replace github.com/coder/websocket => ../.. require ( - github.com/gin-gonic/gin v1.9.1 - github.com/gobwas/ws v1.3.0 - github.com/gorilla/websocket v1.5.0 - nhooyr.io/websocket v0.0.0-00010101000000-000000000000 + github.com/coder/websocket v0.0.0-00010101000000-000000000000 + github.com/gin-gonic/gin v1.10.0 + github.com/gobwas/ws v1.4.0 + github.com/gorilla/websocket v1.5.3 + github.com/lesismal/nbio v1.5.12 ) require ( - github.com/bytedance/sonic v1.9.1 // indirect - github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect - github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cloudwego/base64x v0.1.4 // indirect + github.com/cloudwego/iasm v0.2.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/pool v0.2.1 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/cpuid/v2 v2.2.4 // indirect - github.com/leodido/go-urn v1.2.4 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + github.com/lesismal/llib v1.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - github.com/ugorji/go/codec v1.2.11 // indirect - golang.org/x/arch v0.3.0 // indirect - golang.org/x/crypto v0.9.0 // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect - golang.org/x/text v0.9.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + github.com/ugorji/go/codec v1.2.12 // indirect + golang.org/x/arch v0.8.0 // indirect + golang.org/x/crypto v0.23.0 // indirect + golang.org/x/net v0.25.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.15.0 // indirect + google.golang.org/protobuf v1.34.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/thirdparty/go.sum b/internal/thirdparty/go.sum index a9424b8d..a7be7082 100644 --- a/internal/thirdparty/go.sum +++ b/internal/thirdparty/go.sum @@ -1,93 +1,110 @@ -github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= -github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= -github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= -github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= -github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= -github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= -github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= -github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= -github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0= -github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= +github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs= +github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= -github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= -github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= -github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lesismal/llib v1.1.13 h1:+w1+t0PykXpj2dXQck0+p6vdC9/mnbEXHgUy/HXDGfE= +github.com/lesismal/llib v1.1.13/go.mod h1:70tFXXe7P1FZ02AU9l8LgSOK7d7sRrpnkUr3rd3gKSg= +github.com/lesismal/nbio v1.5.12 h1:YcUjjmOvmKEANs6Oo175JogXvHy8CuE7i6ccjM2/tv4= +github.com/lesismal/nbio v1.5.12/go.mod h1:QsxE0fKFe1PioyjuHVDn2y8ktYK7xv9MFbpkoRFj8vI= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= -github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= -github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= -golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20210513122933-cd7d49e622d5/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/internal/xsync/go_test.go b/internal/xsync/go_test.go index dabea8a5..a3f7053b 100644 --- a/internal/xsync/go_test.go +++ b/internal/xsync/go_test.go @@ -3,7 +3,7 @@ package xsync import ( "testing" - "nhooyr.io/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/assert" ) func TestGoRecover(t *testing.T) { diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go deleted file mode 100644 index a0c40204..00000000 --- a/internal/xsync/int64.go +++ /dev/null @@ -1,23 +0,0 @@ -package xsync - -import ( - "sync/atomic" -) - -// Int64 represents an atomic int64. -type Int64 struct { - // We do not use atomic.Load/StoreInt64 since it does not - // work on 32 bit computers but we need 64 bit integers. - i atomic.Value -} - -// Load loads the int64. -func (v *Int64) Load() int64 { - i, _ := v.i.Load().(int64) - return i -} - -// Store stores the int64. -func (v *Int64) Store(i int64) { - v.i.Store(i) -} diff --git a/make.sh b/make.sh deleted file mode 100755 index 170d00a8..00000000 --- a/make.sh +++ /dev/null @@ -1,12 +0,0 @@ -#!/bin/sh -set -eu -cd -- "$(dirname "$0")" - -echo "=== fmt.sh" -./ci/fmt.sh -echo "=== lint.sh" -./ci/lint.sh -echo "=== test.sh" -./ci/test.sh "$@" -echo "=== bench.sh" -./ci/bench.sh diff --git a/mask.go b/mask.go new file mode 100644 index 00000000..7bc0c8d5 --- /dev/null +++ b/mask.go @@ -0,0 +1,128 @@ +package websocket + +import ( + "encoding/binary" + "math/bits" +) + +// maskGo applies the WebSocket masking algorithm to p +// with the given key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the correctly rotated key to +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +// +// See https://github.com/golang/go/issues/31586 +func maskGo(b []byte, key uint32) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes. + for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + v = binary.LittleEndian.Uint64(b[64:72]) + binary.LittleEndian.PutUint64(b[64:72], v^key64) + v = binary.LittleEndian.Uint64(b[72:80]) + binary.LittleEndian.PutUint64(b[72:80], v^key64) + v = binary.LittleEndian.Uint64(b[80:88]) + binary.LittleEndian.PutUint64(b[80:88], v^key64) + v = binary.LittleEndian.Uint64(b[88:96]) + binary.LittleEndian.PutUint64(b[88:96], v^key64) + v = binary.LittleEndian.Uint64(b[96:104]) + binary.LittleEndian.PutUint64(b[96:104], v^key64) + v = binary.LittleEndian.Uint64(b[104:112]) + binary.LittleEndian.PutUint64(b[104:112], v^key64) + v = binary.LittleEndian.Uint64(b[112:120]) + binary.LittleEndian.PutUint64(b[112:120], v^key64) + v = binary.LittleEndian.Uint64(b[120:128]) + binary.LittleEndian.PutUint64(b[120:128], v^key64) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/mask_amd64.s b/mask_amd64.s new file mode 100644 index 00000000..bd42be31 --- /dev/null +++ b/mask_amd64.s @@ -0,0 +1,127 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // AX = b + // CX = len (left length) + // SI = key (uint32) + // DI = uint64(SI) | uint64(SI)<<32 + MOVQ b+0(FP), AX + MOVQ len+8(FP), CX + MOVL key+16(FP), SI + + // calculate the DI + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + CMPQ CX, $15 + JLE less_than_16 + CMPQ CX, $63 + JLE less_than_64 + CMPQ CX, $128 + JLE sse + TESTQ $31, AX + JNZ unaligned + +unaligned_loop_1byte: + XORB SI, (AX) + INCQ AX + DECQ CX + ROLL $24, SI + TESTQ $7, AX + JNZ unaligned_loop_1byte + + // calculate DI again since SI was modified + // DI = SI<<32 | SI + MOVL SI, DI + MOVQ DI, DX + SHLQ $32, DI + ORQ DX, DI + + TESTQ $31, AX + JZ sse + +unaligned: + TESTQ $7, AX // AND $7 & len, if not zero jump to loop_1b. + JNZ unaligned_loop_1byte + +unaligned_loop: + // we don't need to check the CX since we know it's above 128 + XORQ DI, (AX) + ADDQ $8, AX + SUBQ $8, CX + TESTQ $31, AX + JNZ unaligned_loop + JMP sse + +sse: + CMPQ CX, $0x40 + JL less_than_64 + MOVQ DI, X0 + PUNPCKLQDQ X0, X0 + +sse_loop: + MOVOU 0*16(AX), X1 + MOVOU 1*16(AX), X2 + MOVOU 2*16(AX), X3 + MOVOU 3*16(AX), X4 + PXOR X0, X1 + PXOR X0, X2 + PXOR X0, X3 + PXOR X0, X4 + MOVOU X1, 0*16(AX) + MOVOU X2, 1*16(AX) + MOVOU X3, 2*16(AX) + MOVOU X4, 3*16(AX) + ADDQ $0x40, AX + SUBQ $0x40, CX + CMPQ CX, $0x40 + JAE sse_loop + +less_than_64: + TESTQ $32, CX + JZ less_than_32 + XORQ DI, (AX) + XORQ DI, 8(AX) + XORQ DI, 16(AX) + XORQ DI, 24(AX) + ADDQ $32, AX + +less_than_32: + TESTQ $16, CX + JZ less_than_16 + XORQ DI, (AX) + XORQ DI, 8(AX) + ADDQ $16, AX + +less_than_16: + TESTQ $8, CX + JZ less_than_8 + XORQ DI, (AX) + ADDQ $8, AX + +less_than_8: + TESTQ $4, CX + JZ less_than_4 + XORL SI, (AX) + ADDQ $4, AX + +less_than_4: + TESTQ $2, CX + JZ less_than_2 + XORW SI, (AX) + ROLL $16, SI + ADDQ $2, AX + +less_than_2: + TESTQ $1, CX + JZ done + XORB SI, (AX) + ROLL $24, SI + +done: + MOVL SI, ret+24(FP) + RET diff --git a/mask_arm64.s b/mask_arm64.s new file mode 100644 index 00000000..e494b43a --- /dev/null +++ b/mask_arm64.s @@ -0,0 +1,72 @@ +#include "textflag.h" + +// func maskAsm(b *byte, len int, key uint32) +TEXT ·maskAsm(SB), NOSPLIT, $0-28 + // R0 = b + // R1 = len + // R3 = key (uint32) + // R2 = uint64(key)<<32 | uint64(key) + MOVD b_ptr+0(FP), R0 + MOVD b_len+8(FP), R1 + MOVWU key+16(FP), R3 + MOVD R3, R2 + ORR R2<<32, R2, R2 + VDUP R2, V0.D2 + CMP $64, R1 + BLT less_than_64 + +loop_64: + VLD1 (R0), [V1.B16, V2.B16, V3.B16, V4.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VEOR V3.B16, V0.B16, V3.B16 + VEOR V4.B16, V0.B16, V4.B16 + VST1.P [V1.B16, V2.B16, V3.B16, V4.B16], 64(R0) + SUBS $64, R1 + CMP $64, R1 + BGE loop_64 + +less_than_64: + CBZ R1, end + TBZ $5, R1, less_than_32 + VLD1 (R0), [V1.B16, V2.B16] + VEOR V1.B16, V0.B16, V1.B16 + VEOR V2.B16, V0.B16, V2.B16 + VST1.P [V1.B16, V2.B16], 32(R0) + +less_than_32: + TBZ $4, R1, less_than_16 + LDP (R0), (R11, R12) + EOR R11, R2, R11 + EOR R12, R2, R12 + STP.P (R11, R12), 16(R0) + +less_than_16: + TBZ $3, R1, less_than_8 + MOVD (R0), R11 + EOR R2, R11, R11 + MOVD.P R11, 8(R0) + +less_than_8: + TBZ $2, R1, less_than_4 + MOVWU (R0), R11 + EORW R2, R11, R11 + MOVWU.P R11, 4(R0) + +less_than_4: + TBZ $1, R1, less_than_2 + MOVHU (R0), R11 + EORW R3, R11, R11 + MOVHU.P R11, 2(R0) + RORW $16, R3 + +less_than_2: + TBZ $0, R1, end + MOVBU (R0), R11 + EORW R3, R11, R11 + MOVBU.P R11, 1(R0) + RORW $8, R3 + +end: + MOVWU R3, ret+24(FP) + RET diff --git a/mask_asm.go b/mask_asm.go new file mode 100644 index 00000000..f9484b5b --- /dev/null +++ b/mask_asm.go @@ -0,0 +1,26 @@ +//go:build amd64 || arm64 + +package websocket + +func mask(b []byte, key uint32) uint32 { + // TODO: Will enable in v1.9.0. + return maskGo(b, key) + /* + if len(b) > 0 { + return maskAsm(&b[0], len(b), key) + } + return key + */ +} + +// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this +// function are perfect. There are almost certainly missing optimizations or +// opportunities for simplification. I'm confident there are no bugs though. +// For example, the arm64 implementation doesn't align memory like the amd64. +// Or the amd64 implementation could use AVX512 instead of just AVX2. +// The AVX2 code I had to disable anyway as it wasn't performing as expected. +// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049 +// +//go:noescape +//lint:ignore U1000 disabled till v1.9.0 +func maskAsm(b *byte, len int, key uint32) uint32 diff --git a/mask_asm_test.go b/mask_asm_test.go new file mode 100644 index 00000000..416cbc43 --- /dev/null +++ b/mask_asm_test.go @@ -0,0 +1,11 @@ +//go:build amd64 || arm64 + +package websocket + +import "testing" + +func TestMaskASM(t *testing.T) { + t.Parallel() + + testMask(t, "maskASM", mask) +} diff --git a/mask_go.go b/mask_go.go new file mode 100644 index 00000000..b29435e9 --- /dev/null +++ b/mask_go.go @@ -0,0 +1,7 @@ +//go:build !amd64 && !arm64 && !js + +package websocket + +func mask(b []byte, key uint32) uint32 { + return maskGo(b, key) +} diff --git a/mask_test.go b/mask_test.go new file mode 100644 index 00000000..00a9f0a2 --- /dev/null +++ b/mask_test.go @@ -0,0 +1,73 @@ +package websocket + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "math/big" + "math/bits" + "testing" + + "github.com/coder/websocket/internal/test/assert" +) + +func basicMask(b []byte, key uint32) uint32 { + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + return key +} + +func basicMask2(b []byte, key uint32) uint32 { + keyb := binary.LittleEndian.AppendUint32(nil, key) + pos := 0 + for i := range b { + b[i] ^= keyb[pos&3] + pos++ + } + return bits.RotateLeft32(key, (pos&3)*-8) +} + +func TestMask(t *testing.T) { + t.Parallel() + + testMask(t, "basicMask", basicMask) + testMask(t, "maskGo", maskGo) + testMask(t, "basicMask2", basicMask2) +} + +func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) { + t.Run(name, func(t *testing.T) { + t.Parallel() + for i := 0; i < 9999; i++ { + keyb := make([]byte, 4) + _, err := rand.Read(keyb) + assert.Success(t, err) + key := binary.LittleEndian.Uint32(keyb) + + n, err := rand.Int(rand.Reader, big.NewInt(1<<16)) + assert.Success(t, err) + + b := make([]byte, 1+n.Int64()) + _, err = rand.Read(b) + assert.Success(t, err) + + b2 := make([]byte, len(b)) + copy(b2, b) + b3 := make([]byte, len(b)) + copy(b3, b) + + key2 := basicMask(b2, key) + key3 := fn(b3, key) + + if key2 != key3 { + t.Errorf("expected key %X but got %X", key2, key3) + } + if !bytes.Equal(b2, b3) { + t.Error("bad bytes") + return + } + } + }) +} diff --git a/netconn.go b/netconn.go index 1667f45c..b118e4d3 100644 --- a/netconn.go +++ b/netconn.go @@ -68,7 +68,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.writeMu.unlock() // Prevents future writes from writing until the deadline is reset. - atomic.StoreInt64(&nc.writeExpired, 1) + nc.writeExpired.Store(1) }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C @@ -84,7 +84,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.readMu.unlock() // Prevents future reads from reading until the deadline is reset. - atomic.StoreInt64(&nc.readExpired, 1) + nc.readExpired.Store(1) }) if !nc.readTimer.Stop() { <-nc.readTimer.C @@ -99,13 +99,13 @@ type netConn struct { writeTimer *time.Timer writeMu *mu - writeExpired int64 + writeExpired atomic.Int64 writeCtx context.Context writeCancel context.CancelFunc readTimer *time.Timer readMu *mu - readExpired int64 + readExpired atomic.Int64 readCtx context.Context readCancel context.CancelFunc readEOFed bool @@ -126,7 +126,7 @@ func (nc *netConn) Write(p []byte) (int, error) { nc.writeMu.forceLock() defer nc.writeMu.unlock() - if atomic.LoadInt64(&nc.writeExpired) == 1 { + if nc.writeExpired.Load() == 1 { return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) } @@ -154,7 +154,7 @@ func (nc *netConn) Read(p []byte) (int, error) { } func (nc *netConn) read(p []byte) (int, error) { - if atomic.LoadInt64(&nc.readExpired) == 1 { + if nc.readExpired.Load() == 1 { return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) } @@ -206,7 +206,7 @@ func (nc *netConn) SetDeadline(t time.Time) error { } func (nc *netConn) SetWriteDeadline(t time.Time) error { - atomic.StoreInt64(&nc.writeExpired, 0) + nc.writeExpired.Store(0) if t.IsZero() { nc.writeTimer.Stop() } else { @@ -220,7 +220,7 @@ func (nc *netConn) SetWriteDeadline(t time.Time) error { } func (nc *netConn) SetReadDeadline(t time.Time) error { - atomic.StoreInt64(&nc.readExpired, 0) + nc.readExpired.Store(0) if t.IsZero() { nc.readTimer.Stop() } else { diff --git a/read.go b/read.go index 8742842e..2db22435 100644 --- a/read.go +++ b/read.go @@ -11,11 +11,11 @@ import ( "io" "net" "strings" + "sync/atomic" "time" - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/util" - "nhooyr.io/websocket/internal/xsync" + "github.com/coder/websocket/internal/errd" + "github.com/coder/websocket/internal/util" ) // Reader reads from the connection until there is a WebSocket @@ -60,14 +60,24 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { // Call CloseRead when you do not expect to read any more messages. // Since it actively reads from the connection, it will ensure that ping, pong and close // frames are responded to. This means c.Ping and c.Close will still work as expected. +// +// This function is idempotent. func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) + c.closeReadCtx = ctx + c.closeReadDone = make(chan struct{}) + c.closeReadMu.Unlock() - c.wg.Add(1) go func() { - defer c.CloseNow() - defer c.wg.Done() + defer close(c.closeReadDone) defer cancel() + defer c.close() _, _, err := c.Reader(ctx) if err == nil { c.Close(StatusPolicyViolation, "unexpected data message") @@ -207,60 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } } -func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { +// prepareRead sets the readTimeout context and returns a done function +// to be called after the read is done. It also returns an error if the +// connection is closed. The reference to the error is used to assign +// an error depending on if the connection closed or the context timed +// out during use. Typically the referenced error is a named return +// variable of the function calling this method. +func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: - return header{}, net.ErrClosed + return nil, net.ErrClosed case c.readTimeout <- ctx: } - h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) - if err != nil { + done := func() { select { case <-c.closed: - return header{}, net.ErrClosed - case <-ctx.Done(): - return header{}, ctx.Err() - default: - c.close(err) - return header{}, err + if *err != nil { + *err = net.ErrClosed + } + case c.readTimeout <- context.Background(): + } + if *err != nil && ctx.Err() != nil { + *err = ctx.Err() } } - select { - case <-c.closed: - return header{}, net.ErrClosed - case c.readTimeout <- context.Background(): + c.closeStateMu.Lock() + closeReceivedErr := c.closeReceivedErr + c.closeStateMu.Unlock() + if closeReceivedErr != nil { + defer done() + return nil, closeReceivedErr } - return h, nil + return done, nil } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - select { - case <-c.closed: - return 0, net.ErrClosed - case c.readTimeout <- ctx: +func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return header{}, err } + defer readDone() - n, err := io.ReadFull(c.br, p) + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { - select { - case <-c.closed: - return n, net.ErrClosed - case <-ctx.Done(): - return n, ctx.Err() - default: - err = fmt.Errorf("failed to read frame payload: %w", err) - c.close(err) - return n, err - } + return header{}, err } - select { - case <-c.closed: - return n, net.ErrClosed - case c.readTimeout <- context.Background(): + return h, nil +} + +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { + readDone, err := c.prepareRead(ctx, &err) + if err != nil { + return 0, err + } + defer readDone() + + n, err := io.ReadFull(c.br, p) + if err != nil { + return n, fmt.Errorf("failed to read frame payload: %w", err) } return n, err @@ -289,13 +307,21 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } if h.masked { - mask(h.maskKey, b) + mask(b, h.maskKey) } switch h.opcode { case opPing: + if c.onPingReceived != nil { + if !c.onPingReceived(ctx, b) { + return nil + } + } return c.writeControl(ctx, opPong, b) case opPong: + if c.onPongReceived != nil { + c.onPongReceived(ctx, b) + } c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock() @@ -308,9 +334,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { return nil } - defer func() { - c.readCloseFrameErr = err - }() + // opClose ce, err := parseClosePayload(b) if err != nil { @@ -320,9 +344,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { } err = fmt.Errorf("received close frame: %w", ce) - c.setCloseErr(err) - c.writeClose(ce.Code, ce.Reason) - c.close(err) + c.closeStateMu.Lock() + c.closeReceivedErr = err + closeSent := c.closeSentErr != nil + c.closeStateMu.Unlock() + + // Only unlock readMu if this connection is being closed becaue + // c.close will try to acquire the readMu lock. We unlock for + // writeClose as well because it may also call c.close. + if !closeSent { + c.readMu.unlock() + _ = c.writeClose(ce.Code, ce.Reason) + } + if !c.casClosing() { + c.readMu.unlock() + _ = c.close() + } return err } @@ -336,9 +373,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.unlock() if !c.msgReader.fin { - err = errors.New("previous message not read to completion") - c.close(fmt.Errorf("failed to get reader: %w", err)) - return 0, nil, err + return 0, nil, errors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -411,10 +446,9 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { return n, io.EOF } if err != nil { - err = fmt.Errorf("failed to read: %w", err) - mr.c.close(err) + return n, fmt.Errorf("failed to read: %w", err) } - return n, err + return n, nil } func (mr *msgReader) read(p []byte) (int, error) { @@ -453,7 +487,7 @@ func (mr *msgReader) read(p []byte) (int, error) { mr.payloadLength -= int64(n) if !mr.c.client { - mr.maskKey = mask(mr.maskKey, p) + mr.maskKey = mask(p, mr.maskKey) } return n, nil @@ -463,7 +497,7 @@ func (mr *msgReader) read(p []byte) (int, error) { type limitReader struct { c *Conn r io.Reader - limit xsync.Int64 + limit atomic.Int64 n int64 } diff --git a/write.go b/write.go index 7b1152ce..7324de74 100644 --- a/write.go +++ b/write.go @@ -5,6 +5,7 @@ package websocket import ( "bufio" + "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -14,10 +15,8 @@ import ( "net" "time" - "compress/flate" - - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/util" + "github.com/coder/websocket/internal/errd" + "github.com/coder/websocket/internal/util" ) // Writer returns a writer bounded by the context that will write @@ -159,7 +158,6 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to write: %w", err) - mw.c.close(err) } }() @@ -242,49 +240,44 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error return nil } -// frame handles all writes to the connection. +// writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { err = c.writeFrameMu.lock(ctx) if err != nil { return 0, err } + defer c.writeFrameMu.unlock() - // If the state says a close has already been written, we wait until - // the connection is closed and return that error. - // - // However, if the frame being written is a close, that means its the close from - // the state being set so we let it go through. - c.closeMu.Lock() - wroteClose := c.wroteClose - c.closeMu.Unlock() - if wroteClose && opcode != opClose { - c.writeFrameMu.unlock() - select { - case <-ctx.Done(): - return 0, ctx.Err() - case <-c.closed: - return 0, net.ErrClosed + defer func() { + if c.isClosed() && opcode == opClose { + err = nil } + if err != nil { + if ctx.Err() != nil { + err = ctx.Err() + } else if c.isClosed() { + err = net.ErrClosed + } + err = fmt.Errorf("failed to write frame: %w", err) + } + }() + + c.closeStateMu.Lock() + closeSentErr := c.closeSentErr + c.closeStateMu.Unlock() + if closeSentErr != nil { + return 0, net.ErrClosed } - defer c.writeFrameMu.unlock() select { case <-c.closed: return 0, net.ErrClosed case c.writeTimeout <- ctx: } - defer func() { - if err != nil { - select { - case <-c.closed: - err = net.ErrClosed - case <-ctx.Done(): - err = ctx.Err() - default: - } - c.close(err) - err = fmt.Errorf("failed to write frame: %w", err) + select { + case <-c.closed: + case c.writeTimeout <- context.Background(): } }() @@ -323,13 +316,16 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco } } - select { - case <-c.closed: - if opcode == opClose { - return n, nil + if opcode == opClose { + c.closeStateMu.Lock() + c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed) + closeReceived := c.closeReceivedErr != nil + c.closeStateMu.Unlock() + + if closeReceived && !c.casClosing() { + c.writeFrameMu.unlock() + _ = c.close() } - return n, net.ErrClosed - case c.writeTimeout <- context.Background(): } return n, nil @@ -365,7 +361,7 @@ func (c *Conn) writeFramePayload(p []byte) (n int, err error) { return n, err } - maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) + maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey) p = p[j:] n += j @@ -392,7 +388,5 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { } func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) c.writeClose(code, err.Error()) - c.close(nil) } diff --git a/ws_js.go b/ws_js.go index b4011b5c..5e324c47 100644 --- a/ws_js.go +++ b/ws_js.go @@ -1,4 +1,4 @@ -package websocket // import "nhooyr.io/websocket" +package websocket // import "github.com/coder/websocket" import ( "bytes" @@ -12,11 +12,11 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall/js" - "nhooyr.io/websocket/internal/bpool" - "nhooyr.io/websocket/internal/wsjs" - "nhooyr.io/websocket/internal/xsync" + "github.com/coder/websocket/internal/bpool" + "github.com/coder/websocket/internal/wsjs" ) // opcode represents a WebSocket opcode. @@ -41,15 +41,16 @@ const ( // Conn provides a wrapper around the browser WebSocket API. type Conn struct { - noCopy - ws wsjs.WebSocket + noCopy noCopy + ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit xsync.Int64 + msgReadLimit atomic.Int64 + + closeReadMu sync.Mutex + closeReadCtx context.Context - wg sync.WaitGroup closingMu sync.Mutex - isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -130,7 +131,10 @@ func (c *Conn) closeWithInternal() { // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - if c.isReadClosed.Load() == 1 { + c.closeReadMu.Lock() + closedRead := c.closeReadCtx != nil + c.closeReadMu.Unlock() + if closedRead { return 0, nil, errors.New("WebSocket connection read closed") } @@ -138,7 +142,8 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if err != nil { return 0, nil, fmt.Errorf("failed to read: %w", err) } - if int64(len(p)) > c.msgReadLimit.Load() { + readLimit := c.msgReadLimit.Load() + if readLimit >= 0 && int64(len(p)) > readLimit { err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err @@ -224,7 +229,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { - defer c.wg.Wait() err := c.exportedClose(code, reason) if err != nil { return fmt.Errorf("failed to close WebSocket: %w", err) @@ -238,7 +242,6 @@ func (c *Conn) Close(code StatusCode, reason string) error { // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close // a WebSocket without the close handshake. func (c *Conn) CloseNow() error { - defer c.wg.Wait() return c.Close(StatusGoingAway, "") } @@ -388,14 +391,19 @@ func (w *writer) Close() error { // CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - + c.closeReadMu.Lock() + ctx2 := c.closeReadCtx + if ctx2 != nil { + c.closeReadMu.Unlock() + return ctx2 + } ctx, cancel := context.WithCancel(ctx) - c.wg.Add(1) + c.closeReadCtx = ctx + c.closeReadMu.Unlock() + go func() { - defer c.CloseNow() - defer c.wg.Done() defer cancel() + defer c.CloseNow() _, _, err := c.read(ctx) if err != nil { c.Close(StatusPolicyViolation, "unexpected data message") diff --git a/ws_js_test.go b/ws_js_test.go index ba98b9a0..b56ad16b 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/assert" - "nhooyr.io/websocket/internal/test/wstest" + "github.com/coder/websocket" + "github.com/coder/websocket/internal/test/assert" + "github.com/coder/websocket/internal/test/wstest" ) func TestWasm(t *testing.T) { diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 7c986a0d..05e7cfa1 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -1,15 +1,15 @@ // Package wsjson provides helpers for reading and writing JSON messages. -package wsjson // import "nhooyr.io/websocket/wsjson" +package wsjson // import "github.com/coder/websocket/wsjson" import ( "context" "encoding/json" "fmt" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" - "nhooyr.io/websocket/internal/errd" - "nhooyr.io/websocket/internal/util" + "github.com/coder/websocket" + "github.com/coder/websocket/internal/bpool" + "github.com/coder/websocket/internal/errd" + "github.com/coder/websocket/internal/util" ) // Read reads a JSON message from c into v. diff --git a/wsjson/wsjson_test.go b/wsjson/wsjson_test.go new file mode 100644 index 00000000..87a72854 --- /dev/null +++ b/wsjson/wsjson_test.go @@ -0,0 +1,53 @@ +package wsjson_test + +import ( + "encoding/json" + "io" + "strconv" + "testing" + + "github.com/coder/websocket/internal/test/xrand" +) + +func BenchmarkJSON(b *testing.B) { + sizes := []int{ + 8, + 16, + 32, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384, + } + + b.Run("json.Encoder", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + msg := xrand.String(size) + b.SetBytes(int64(size)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + json.NewEncoder(io.Discard).Encode(msg) + } + }) + } + }) + b.Run("json.Marshal", func(b *testing.B) { + for _, size := range sizes { + b.Run(strconv.Itoa(size), func(b *testing.B) { + msg := xrand.String(size) + b.SetBytes(int64(size)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + json.Marshal(msg) + } + }) + } + }) +}