diff --git a/LICENSE b/LICENSE
index 6a66aea5ea..2a7cf70da6 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,4 +1,4 @@
-Copyright (c) 2009 The Go Authors. All rights reserved.
+Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
- * Neither the name of Google Inc. nor the names of its
+ * Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
diff --git a/README.md b/README.md
index a15f253dff..235c8d8b30 100644
--- a/README.md
+++ b/README.md
@@ -2,17 +2,15 @@
[](https://pkg.go.dev/golang.org/x/net)
-This repository holds supplementary Go networking libraries.
+This repository holds supplementary Go networking packages.
-## Download/Install
+## Report Issues / Send Patches
-The easiest way to install is to run `go get -u golang.org/x/net`. You can
-also manually git clone the repository to `$GOPATH/src/golang.org/x/net`.
+This repository uses Gerrit for code changes. To learn how to submit changes to
+this repository, see https://go.dev/doc/contribute.
-## Report Issues / Send Patches
+The git repository is https://go.googlesource.com/net.
-This repository uses Gerrit for code changes. To learn how to submit
-changes to this repository, see https://golang.org/doc/contribute.html.
The main issue tracker for the net repository is located at
-https://github.com/golang/go/issues. Prefix your issue with "x/net:" in the
+https://go.dev/issues. Prefix your issue with "x/net:" in the
subject line, so it is easy to find.
diff --git a/context/context.go b/context/context.go
index cf66309c4a..db1c95fab1 100644
--- a/context/context.go
+++ b/context/context.go
@@ -3,29 +3,31 @@
// license that can be found in the LICENSE file.
// Package context defines the Context type, which carries deadlines,
-// cancelation signals, and other request-scoped values across API boundaries
+// cancellation signals, and other request-scoped values across API boundaries
// and between processes.
// As of Go 1.7 this package is available in the standard library under the
-// name context. https://golang.org/pkg/context.
+// name [context], and migrating to it can be done automatically with [go fix].
//
-// Incoming requests to a server should create a Context, and outgoing calls to
-// servers should accept a Context. The chain of function calls between must
-// propagate the Context, optionally replacing it with a modified copy created
-// using WithDeadline, WithTimeout, WithCancel, or WithValue.
+// Incoming requests to a server should create a [Context], and outgoing
+// calls to servers should accept a Context. The chain of function
+// calls between them must propagate the Context, optionally replacing
+// it with a derived Context created using [WithCancel], [WithDeadline],
+// [WithTimeout], or [WithValue].
//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
//
// Do not store Contexts inside a struct type; instead, pass a Context
-// explicitly to each function that needs it. The Context should be the first
+// explicitly to each function that needs it. This is discussed further in
+// https://go.dev/blog/context-and-structs. The Context should be the first
// parameter, typically named ctx:
//
// func DoSomething(ctx context.Context, arg Arg) error {
// // ... use ctx ...
// }
//
-// Do not pass a nil Context, even if a function permits it. Pass context.TODO
+// Do not pass a nil [Context], even if a function permits it. Pass [context.TODO]
// if you are unsure about which Context to use.
//
// Use context Values only for request-scoped data that transits processes and
@@ -34,9 +36,30 @@
// The same Context may be passed to functions running in different goroutines;
// Contexts are safe for simultaneous use by multiple goroutines.
//
-// See http://blog.golang.org/context for example code for a server that uses
+// See https://go.dev/blog/context for example code for a server that uses
// Contexts.
-package context // import "golang.org/x/net/context"
+//
+// [go fix]: https://go.dev/cmd/go#hdr-Update_packages_to_use_new_APIs
+package context
+
+import (
+ "context" // standard library's context, as of Go 1.7
+ "time"
+)
+
+// A Context carries a deadline, a cancellation signal, and other values across
+// API boundaries.
+//
+// Context's methods may be called by multiple goroutines simultaneously.
+type Context = context.Context
+
+// Canceled is the error returned by [Context.Err] when the context is canceled
+// for some reason other than its deadline passing.
+var Canceled = context.Canceled
+
+// DeadlineExceeded is the error returned by [Context.Err] when the context is canceled
+// due to its deadline passing.
+var DeadlineExceeded = context.DeadlineExceeded
// Background returns a non-nil, empty Context. It is never canceled, has no
// values, and has no deadline. It is typically used by the main function,
@@ -49,8 +72,73 @@ func Background() Context {
// TODO returns a non-nil, empty Context. Code should use context.TODO when
// it's unclear which Context to use or it is not yet available (because the
// surrounding function has not yet been extended to accept a Context
-// parameter). TODO is recognized by static analysis tools that determine
-// whether Contexts are propagated correctly in a program.
+// parameter).
func TODO() Context {
return todo
}
+
+var (
+ background = context.Background()
+ todo = context.TODO()
+)
+
+// A CancelFunc tells an operation to abandon its work.
+// A CancelFunc does not wait for the work to stop.
+// A CancelFunc may be called by multiple goroutines simultaneously.
+// After the first call, subsequent calls to a CancelFunc do nothing.
+type CancelFunc = context.CancelFunc
+
+// WithCancel returns a derived context that points to the parent context
+// but has a new Done channel. The returned context's Done channel is closed
+// when the returned cancel function is called or when the parent context's
+// Done channel is closed, whichever happens first.
+//
+// Canceling this context releases resources associated with it, so code should
+// call cancel as soon as the operations running in this [Context] complete.
+func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
+ return context.WithCancel(parent)
+}
+
+// WithDeadline returns a derived context that points to the parent context
+// but has the deadline adjusted to be no later than d. If the parent's
+// deadline is already earlier than d, WithDeadline(parent, d) is semantically
+// equivalent to parent. The returned [Context.Done] channel is closed when
+// the deadline expires, when the returned cancel function is called,
+// or when the parent context's Done channel is closed, whichever happens first.
+//
+// Canceling this context releases resources associated with it, so code should
+// call cancel as soon as the operations running in this [Context] complete.
+func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
+ return context.WithDeadline(parent, d)
+}
+
+// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
+//
+// Canceling this context releases resources associated with it, so code should
+// call cancel as soon as the operations running in this [Context] complete:
+//
+// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
+// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
+// defer cancel() // releases resources if slowOperation completes before timeout elapses
+// return slowOperation(ctx)
+// }
+func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
+ return context.WithTimeout(parent, timeout)
+}
+
+// WithValue returns a derived context that points to the parent Context.
+// In the derived context, the value associated with key is val.
+//
+// Use context Values only for request-scoped data that transits processes and
+// APIs, not for passing optional parameters to functions.
+//
+// The provided key must be comparable and should not be of type
+// string or any other built-in type to avoid collisions between
+// packages using context. Users of WithValue should define their own
+// types for keys. To avoid allocating when assigning to an
+// interface{}, context keys often have concrete type
+// struct{}. Alternatively, exported context key variables' static
+// type should be a pointer or interface.
+func WithValue(parent Context, key, val interface{}) Context {
+ return context.WithValue(parent, key, val)
+}
diff --git a/context/context_test.go b/context/context_test.go
deleted file mode 100644
index 2cb54edb89..0000000000
--- a/context/context_test.go
+++ /dev/null
@@ -1,583 +0,0 @@
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.7
-
-package context
-
-import (
- "fmt"
- "math/rand"
- "runtime"
- "strings"
- "sync"
- "testing"
- "time"
-)
-
-// otherContext is a Context that's not one of the types defined in context.go.
-// This lets us test code paths that differ based on the underlying type of the
-// Context.
-type otherContext struct {
- Context
-}
-
-func TestBackground(t *testing.T) {
- c := Background()
- if c == nil {
- t.Fatalf("Background returned nil")
- }
- select {
- case x := <-c.Done():
- t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
- default:
- }
- if got, want := fmt.Sprint(c), "context.Background"; got != want {
- t.Errorf("Background().String() = %q want %q", got, want)
- }
-}
-
-func TestTODO(t *testing.T) {
- c := TODO()
- if c == nil {
- t.Fatalf("TODO returned nil")
- }
- select {
- case x := <-c.Done():
- t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
- default:
- }
- if got, want := fmt.Sprint(c), "context.TODO"; got != want {
- t.Errorf("TODO().String() = %q want %q", got, want)
- }
-}
-
-func TestWithCancel(t *testing.T) {
- c1, cancel := WithCancel(Background())
-
- if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want {
- t.Errorf("c1.String() = %q want %q", got, want)
- }
-
- o := otherContext{c1}
- c2, _ := WithCancel(o)
- contexts := []Context{c1, o, c2}
-
- for i, c := range contexts {
- if d := c.Done(); d == nil {
- t.Errorf("c[%d].Done() == %v want non-nil", i, d)
- }
- if e := c.Err(); e != nil {
- t.Errorf("c[%d].Err() == %v want nil", i, e)
- }
-
- select {
- case x := <-c.Done():
- t.Errorf("<-c.Done() == %v want nothing (it should block)", x)
- default:
- }
- }
-
- cancel()
- time.Sleep(100 * time.Millisecond) // let cancelation propagate
-
- for i, c := range contexts {
- select {
- case <-c.Done():
- default:
- t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i)
- }
- if e := c.Err(); e != Canceled {
- t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled)
- }
- }
-}
-
-func TestParentFinishesChild(t *testing.T) {
- // Context tree:
- // parent -> cancelChild
- // parent -> valueChild -> timerChild
- parent, cancel := WithCancel(Background())
- cancelChild, stop := WithCancel(parent)
- defer stop()
- valueChild := WithValue(parent, "key", "value")
- timerChild, stop := WithTimeout(valueChild, 10000*time.Hour)
- defer stop()
-
- select {
- case x := <-parent.Done():
- t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
- case x := <-cancelChild.Done():
- t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x)
- case x := <-timerChild.Done():
- t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x)
- case x := <-valueChild.Done():
- t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x)
- default:
- }
-
- // The parent's children should contain the two cancelable children.
- pc := parent.(*cancelCtx)
- cc := cancelChild.(*cancelCtx)
- tc := timerChild.(*timerCtx)
- pc.mu.Lock()
- if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] {
- t.Errorf("bad linkage: pc.children = %v, want %v and %v",
- pc.children, cc, tc)
- }
- pc.mu.Unlock()
-
- if p, ok := parentCancelCtx(cc.Context); !ok || p != pc {
- t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc)
- }
- if p, ok := parentCancelCtx(tc.Context); !ok || p != pc {
- t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc)
- }
-
- cancel()
-
- pc.mu.Lock()
- if len(pc.children) != 0 {
- t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children)
- }
- pc.mu.Unlock()
-
- // parent and children should all be finished.
- check := func(ctx Context, name string) {
- select {
- case <-ctx.Done():
- default:
- t.Errorf("<-%s.Done() blocked, but shouldn't have", name)
- }
- if e := ctx.Err(); e != Canceled {
- t.Errorf("%s.Err() == %v want %v", name, e, Canceled)
- }
- }
- check(parent, "parent")
- check(cancelChild, "cancelChild")
- check(valueChild, "valueChild")
- check(timerChild, "timerChild")
-
- // WithCancel should return a canceled context on a canceled parent.
- precanceledChild := WithValue(parent, "key", "value")
- select {
- case <-precanceledChild.Done():
- default:
- t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have")
- }
- if e := precanceledChild.Err(); e != Canceled {
- t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled)
- }
-}
-
-func TestChildFinishesFirst(t *testing.T) {
- cancelable, stop := WithCancel(Background())
- defer stop()
- for _, parent := range []Context{Background(), cancelable} {
- child, cancel := WithCancel(parent)
-
- select {
- case x := <-parent.Done():
- t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
- case x := <-child.Done():
- t.Errorf("<-child.Done() == %v want nothing (it should block)", x)
- default:
- }
-
- cc := child.(*cancelCtx)
- pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background()
- if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) {
- t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok)
- }
-
- if pcok {
- pc.mu.Lock()
- if len(pc.children) != 1 || !pc.children[cc] {
- t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc)
- }
- pc.mu.Unlock()
- }
-
- cancel()
-
- if pcok {
- pc.mu.Lock()
- if len(pc.children) != 0 {
- t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children)
- }
- pc.mu.Unlock()
- }
-
- // child should be finished.
- select {
- case <-child.Done():
- default:
- t.Errorf("<-child.Done() blocked, but shouldn't have")
- }
- if e := child.Err(); e != Canceled {
- t.Errorf("child.Err() == %v want %v", e, Canceled)
- }
-
- // parent should not be finished.
- select {
- case x := <-parent.Done():
- t.Errorf("<-parent.Done() == %v want nothing (it should block)", x)
- default:
- }
- if e := parent.Err(); e != nil {
- t.Errorf("parent.Err() == %v want nil", e)
- }
- }
-}
-
-func testDeadline(c Context, wait time.Duration, t *testing.T) {
- select {
- case <-time.After(wait):
- t.Fatalf("context should have timed out")
- case <-c.Done():
- }
- if e := c.Err(); e != DeadlineExceeded {
- t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded)
- }
-}
-
-func TestDeadline(t *testing.T) {
- t.Parallel()
- const timeUnit = 500 * time.Millisecond
- c, _ := WithDeadline(Background(), time.Now().Add(1*timeUnit))
- if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
- t.Errorf("c.String() = %q want prefix %q", got, prefix)
- }
- testDeadline(c, 2*timeUnit, t)
-
- c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit))
- o := otherContext{c}
- testDeadline(o, 2*timeUnit, t)
-
- c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit))
- o = otherContext{c}
- c, _ = WithDeadline(o, time.Now().Add(3*timeUnit))
- testDeadline(c, 2*timeUnit, t)
-}
-
-func TestTimeout(t *testing.T) {
- t.Parallel()
- const timeUnit = 500 * time.Millisecond
- c, _ := WithTimeout(Background(), 1*timeUnit)
- if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) {
- t.Errorf("c.String() = %q want prefix %q", got, prefix)
- }
- testDeadline(c, 2*timeUnit, t)
-
- c, _ = WithTimeout(Background(), 1*timeUnit)
- o := otherContext{c}
- testDeadline(o, 2*timeUnit, t)
-
- c, _ = WithTimeout(Background(), 1*timeUnit)
- o = otherContext{c}
- c, _ = WithTimeout(o, 3*timeUnit)
- testDeadline(c, 2*timeUnit, t)
-}
-
-func TestCanceledTimeout(t *testing.T) {
- t.Parallel()
- const timeUnit = 500 * time.Millisecond
- c, _ := WithTimeout(Background(), 2*timeUnit)
- o := otherContext{c}
- c, cancel := WithTimeout(o, 4*timeUnit)
- cancel()
- time.Sleep(1 * timeUnit) // let cancelation propagate
- select {
- case <-c.Done():
- default:
- t.Errorf("<-c.Done() blocked, but shouldn't have")
- }
- if e := c.Err(); e != Canceled {
- t.Errorf("c.Err() == %v want %v", e, Canceled)
- }
-}
-
-type key1 int
-type key2 int
-
-var k1 = key1(1)
-var k2 = key2(1) // same int as k1, different type
-var k3 = key2(3) // same type as k2, different int
-
-func TestValues(t *testing.T) {
- check := func(c Context, nm, v1, v2, v3 string) {
- if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 {
- t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0)
- }
- if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 {
- t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0)
- }
- if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 {
- t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0)
- }
- }
-
- c0 := Background()
- check(c0, "c0", "", "", "")
-
- c1 := WithValue(Background(), k1, "c1k1")
- check(c1, "c1", "c1k1", "", "")
-
- if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want {
- t.Errorf("c.String() = %q want %q", got, want)
- }
-
- c2 := WithValue(c1, k2, "c2k2")
- check(c2, "c2", "c1k1", "c2k2", "")
-
- c3 := WithValue(c2, k3, "c3k3")
- check(c3, "c2", "c1k1", "c2k2", "c3k3")
-
- c4 := WithValue(c3, k1, nil)
- check(c4, "c4", "", "c2k2", "c3k3")
-
- o0 := otherContext{Background()}
- check(o0, "o0", "", "", "")
-
- o1 := otherContext{WithValue(Background(), k1, "c1k1")}
- check(o1, "o1", "c1k1", "", "")
-
- o2 := WithValue(o1, k2, "o2k2")
- check(o2, "o2", "c1k1", "o2k2", "")
-
- o3 := otherContext{c4}
- check(o3, "o3", "", "c2k2", "c3k3")
-
- o4 := WithValue(o3, k3, nil)
- check(o4, "o4", "", "c2k2", "")
-}
-
-func TestAllocs(t *testing.T) {
- bg := Background()
- for _, test := range []struct {
- desc string
- f func()
- limit float64
- gccgoLimit float64
- }{
- {
- desc: "Background()",
- f: func() { Background() },
- limit: 0,
- gccgoLimit: 0,
- },
- {
- desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1),
- f: func() {
- c := WithValue(bg, k1, nil)
- c.Value(k1)
- },
- limit: 3,
- gccgoLimit: 3,
- },
- {
- desc: "WithTimeout(bg, 15*time.Millisecond)",
- f: func() {
- c, _ := WithTimeout(bg, 15*time.Millisecond)
- <-c.Done()
- },
- limit: 8,
- gccgoLimit: 16,
- },
- {
- desc: "WithCancel(bg)",
- f: func() {
- c, cancel := WithCancel(bg)
- cancel()
- <-c.Done()
- },
- limit: 5,
- gccgoLimit: 8,
- },
- {
- desc: "WithTimeout(bg, 100*time.Millisecond)",
- f: func() {
- c, cancel := WithTimeout(bg, 100*time.Millisecond)
- cancel()
- <-c.Done()
- },
- limit: 8,
- gccgoLimit: 25,
- },
- } {
- limit := test.limit
- if runtime.Compiler == "gccgo" {
- // gccgo does not yet do escape analysis.
- // TODO(iant): Remove this when gccgo does do escape analysis.
- limit = test.gccgoLimit
- }
- if n := testing.AllocsPerRun(100, test.f); n > limit {
- t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit))
- }
- }
-}
-
-func TestSimultaneousCancels(t *testing.T) {
- root, cancel := WithCancel(Background())
- m := map[Context]CancelFunc{root: cancel}
- q := []Context{root}
- // Create a tree of contexts.
- for len(q) != 0 && len(m) < 100 {
- parent := q[0]
- q = q[1:]
- for i := 0; i < 4; i++ {
- ctx, cancel := WithCancel(parent)
- m[ctx] = cancel
- q = append(q, ctx)
- }
- }
- // Start all the cancels in a random order.
- var wg sync.WaitGroup
- wg.Add(len(m))
- for _, cancel := range m {
- go func(cancel CancelFunc) {
- cancel()
- wg.Done()
- }(cancel)
- }
- // Wait on all the contexts in a random order.
- for ctx := range m {
- select {
- case <-ctx.Done():
- case <-time.After(1 * time.Second):
- buf := make([]byte, 10<<10)
- n := runtime.Stack(buf, true)
- t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n])
- }
- }
- // Wait for all the cancel functions to return.
- done := make(chan struct{})
- go func() {
- wg.Wait()
- close(done)
- }()
- select {
- case <-done:
- case <-time.After(1 * time.Second):
- buf := make([]byte, 10<<10)
- n := runtime.Stack(buf, true)
- t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n])
- }
-}
-
-func TestInterlockedCancels(t *testing.T) {
- parent, cancelParent := WithCancel(Background())
- child, cancelChild := WithCancel(parent)
- go func() {
- parent.Done()
- cancelChild()
- }()
- cancelParent()
- select {
- case <-child.Done():
- case <-time.After(1 * time.Second):
- buf := make([]byte, 10<<10)
- n := runtime.Stack(buf, true)
- t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n])
- }
-}
-
-func TestLayersCancel(t *testing.T) {
- testLayers(t, time.Now().UnixNano(), false)
-}
-
-func TestLayersTimeout(t *testing.T) {
- testLayers(t, time.Now().UnixNano(), true)
-}
-
-func testLayers(t *testing.T, seed int64, testTimeout bool) {
- rand.Seed(seed)
- errorf := func(format string, a ...interface{}) {
- t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...)
- }
- const (
- timeout = 200 * time.Millisecond
- minLayers = 30
- )
- type value int
- var (
- vals []*value
- cancels []CancelFunc
- numTimers int
- ctx = Background()
- )
- for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ {
- switch rand.Intn(3) {
- case 0:
- v := new(value)
- ctx = WithValue(ctx, v, v)
- vals = append(vals, v)
- case 1:
- var cancel CancelFunc
- ctx, cancel = WithCancel(ctx)
- cancels = append(cancels, cancel)
- case 2:
- var cancel CancelFunc
- ctx, cancel = WithTimeout(ctx, timeout)
- cancels = append(cancels, cancel)
- numTimers++
- }
- }
- checkValues := func(when string) {
- for _, key := range vals {
- if val := ctx.Value(key).(*value); key != val {
- errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key)
- }
- }
- }
- select {
- case <-ctx.Done():
- errorf("ctx should not be canceled yet")
- default:
- }
- if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) {
- t.Errorf("ctx.String() = %q want prefix %q", s, prefix)
- }
- t.Log(ctx)
- checkValues("before cancel")
- if testTimeout {
- select {
- case <-ctx.Done():
- case <-time.After(timeout + 100*time.Millisecond):
- errorf("ctx should have timed out")
- }
- checkValues("after timeout")
- } else {
- cancel := cancels[rand.Intn(len(cancels))]
- cancel()
- select {
- case <-ctx.Done():
- default:
- errorf("ctx should be canceled")
- }
- checkValues("after cancel")
- }
-}
-
-func TestCancelRemoves(t *testing.T) {
- checkChildren := func(when string, ctx Context, want int) {
- if got := len(ctx.(*cancelCtx).children); got != want {
- t.Errorf("%s: context has %d children, want %d", when, got, want)
- }
- }
-
- ctx, _ := WithCancel(Background())
- checkChildren("after creation", ctx, 0)
- _, cancel := WithCancel(ctx)
- checkChildren("with WithCancel child ", ctx, 1)
- cancel()
- checkChildren("after cancelling WithCancel child", ctx, 0)
-
- ctx, _ = WithCancel(Background())
- checkChildren("after creation", ctx, 0)
- _, cancel = WithTimeout(ctx, 60*time.Minute)
- checkChildren("with WithTimeout child ", ctx, 1)
- cancel()
- checkChildren("after cancelling WithTimeout child", ctx, 0)
-}
diff --git a/context/ctxhttp/ctxhttp.go b/context/ctxhttp/ctxhttp.go
index 37dc0cfdb5..e0df203cea 100644
--- a/context/ctxhttp/ctxhttp.go
+++ b/context/ctxhttp/ctxhttp.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// Package ctxhttp provides helper functions for performing context-aware HTTP requests.
-package ctxhttp // import "golang.org/x/net/context/ctxhttp"
+package ctxhttp
import (
"context"
diff --git a/context/go17.go b/context/go17.go
deleted file mode 100644
index 0c1b867937..0000000000
--- a/context/go17.go
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2016 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.7
-
-package context
-
-import (
- "context" // standard library's context, as of Go 1.7
- "time"
-)
-
-var (
- todo = context.TODO()
- background = context.Background()
-)
-
-// Canceled is the error returned by Context.Err when the context is canceled.
-var Canceled = context.Canceled
-
-// DeadlineExceeded is the error returned by Context.Err when the context's
-// deadline passes.
-var DeadlineExceeded = context.DeadlineExceeded
-
-// WithCancel returns a copy of parent with a new Done channel. The returned
-// context's Done channel is closed when the returned cancel function is called
-// or when the parent context's Done channel is closed, whichever happens first.
-//
-// Canceling this context releases resources associated with it, so code should
-// call cancel as soon as the operations running in this Context complete.
-func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
- ctx, f := context.WithCancel(parent)
- return ctx, f
-}
-
-// WithDeadline returns a copy of the parent context with the deadline adjusted
-// to be no later than d. If the parent's deadline is already earlier than d,
-// WithDeadline(parent, d) is semantically equivalent to parent. The returned
-// context's Done channel is closed when the deadline expires, when the returned
-// cancel function is called, or when the parent context's Done channel is
-// closed, whichever happens first.
-//
-// Canceling this context releases resources associated with it, so code should
-// call cancel as soon as the operations running in this Context complete.
-func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
- ctx, f := context.WithDeadline(parent, deadline)
- return ctx, f
-}
-
-// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
-//
-// Canceling this context releases resources associated with it, so code should
-// call cancel as soon as the operations running in this Context complete:
-//
-// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
-// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
-// defer cancel() // releases resources if slowOperation completes before timeout elapses
-// return slowOperation(ctx)
-// }
-func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
- return WithDeadline(parent, time.Now().Add(timeout))
-}
-
-// WithValue returns a copy of parent in which the value associated with key is
-// val.
-//
-// Use context Values only for request-scoped data that transits processes and
-// APIs, not for passing optional parameters to functions.
-func WithValue(parent Context, key interface{}, val interface{}) Context {
- return context.WithValue(parent, key, val)
-}
diff --git a/context/go19.go b/context/go19.go
deleted file mode 100644
index e31e35a904..0000000000
--- a/context/go19.go
+++ /dev/null
@@ -1,20 +0,0 @@
-// Copyright 2017 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build go1.9
-
-package context
-
-import "context" // standard library's context, as of Go 1.7
-
-// A Context carries a deadline, a cancelation signal, and other values across
-// API boundaries.
-//
-// Context's methods may be called by multiple goroutines simultaneously.
-type Context = context.Context
-
-// A CancelFunc tells an operation to abandon its work.
-// A CancelFunc does not wait for the work to stop.
-// After the first call, subsequent calls to a CancelFunc do nothing.
-type CancelFunc = context.CancelFunc
diff --git a/context/pre_go17.go b/context/pre_go17.go
deleted file mode 100644
index 065ff3dfa5..0000000000
--- a/context/pre_go17.go
+++ /dev/null
@@ -1,300 +0,0 @@
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.7
-
-package context
-
-import (
- "errors"
- "fmt"
- "sync"
- "time"
-)
-
-// An emptyCtx is never canceled, has no values, and has no deadline. It is not
-// struct{}, since vars of this type must have distinct addresses.
-type emptyCtx int
-
-func (*emptyCtx) Deadline() (deadline time.Time, ok bool) {
- return
-}
-
-func (*emptyCtx) Done() <-chan struct{} {
- return nil
-}
-
-func (*emptyCtx) Err() error {
- return nil
-}
-
-func (*emptyCtx) Value(key interface{}) interface{} {
- return nil
-}
-
-func (e *emptyCtx) String() string {
- switch e {
- case background:
- return "context.Background"
- case todo:
- return "context.TODO"
- }
- return "unknown empty Context"
-}
-
-var (
- background = new(emptyCtx)
- todo = new(emptyCtx)
-)
-
-// Canceled is the error returned by Context.Err when the context is canceled.
-var Canceled = errors.New("context canceled")
-
-// DeadlineExceeded is the error returned by Context.Err when the context's
-// deadline passes.
-var DeadlineExceeded = errors.New("context deadline exceeded")
-
-// WithCancel returns a copy of parent with a new Done channel. The returned
-// context's Done channel is closed when the returned cancel function is called
-// or when the parent context's Done channel is closed, whichever happens first.
-//
-// Canceling this context releases resources associated with it, so code should
-// call cancel as soon as the operations running in this Context complete.
-func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
- c := newCancelCtx(parent)
- propagateCancel(parent, c)
- return c, func() { c.cancel(true, Canceled) }
-}
-
-// newCancelCtx returns an initialized cancelCtx.
-func newCancelCtx(parent Context) *cancelCtx {
- return &cancelCtx{
- Context: parent,
- done: make(chan struct{}),
- }
-}
-
-// propagateCancel arranges for child to be canceled when parent is.
-func propagateCancel(parent Context, child canceler) {
- if parent.Done() == nil {
- return // parent is never canceled
- }
- if p, ok := parentCancelCtx(parent); ok {
- p.mu.Lock()
- if p.err != nil {
- // parent has already been canceled
- child.cancel(false, p.err)
- } else {
- if p.children == nil {
- p.children = make(map[canceler]bool)
- }
- p.children[child] = true
- }
- p.mu.Unlock()
- } else {
- go func() {
- select {
- case <-parent.Done():
- child.cancel(false, parent.Err())
- case <-child.Done():
- }
- }()
- }
-}
-
-// parentCancelCtx follows a chain of parent references until it finds a
-// *cancelCtx. This function understands how each of the concrete types in this
-// package represents its parent.
-func parentCancelCtx(parent Context) (*cancelCtx, bool) {
- for {
- switch c := parent.(type) {
- case *cancelCtx:
- return c, true
- case *timerCtx:
- return c.cancelCtx, true
- case *valueCtx:
- parent = c.Context
- default:
- return nil, false
- }
- }
-}
-
-// removeChild removes a context from its parent.
-func removeChild(parent Context, child canceler) {
- p, ok := parentCancelCtx(parent)
- if !ok {
- return
- }
- p.mu.Lock()
- if p.children != nil {
- delete(p.children, child)
- }
- p.mu.Unlock()
-}
-
-// A canceler is a context type that can be canceled directly. The
-// implementations are *cancelCtx and *timerCtx.
-type canceler interface {
- cancel(removeFromParent bool, err error)
- Done() <-chan struct{}
-}
-
-// A cancelCtx can be canceled. When canceled, it also cancels any children
-// that implement canceler.
-type cancelCtx struct {
- Context
-
- done chan struct{} // closed by the first cancel call.
-
- mu sync.Mutex
- children map[canceler]bool // set to nil by the first cancel call
- err error // set to non-nil by the first cancel call
-}
-
-func (c *cancelCtx) Done() <-chan struct{} {
- return c.done
-}
-
-func (c *cancelCtx) Err() error {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.err
-}
-
-func (c *cancelCtx) String() string {
- return fmt.Sprintf("%v.WithCancel", c.Context)
-}
-
-// cancel closes c.done, cancels each of c's children, and, if
-// removeFromParent is true, removes c from its parent's children.
-func (c *cancelCtx) cancel(removeFromParent bool, err error) {
- if err == nil {
- panic("context: internal error: missing cancel error")
- }
- c.mu.Lock()
- if c.err != nil {
- c.mu.Unlock()
- return // already canceled
- }
- c.err = err
- close(c.done)
- for child := range c.children {
- // NOTE: acquiring the child's lock while holding parent's lock.
- child.cancel(false, err)
- }
- c.children = nil
- c.mu.Unlock()
-
- if removeFromParent {
- removeChild(c.Context, c)
- }
-}
-
-// WithDeadline returns a copy of the parent context with the deadline adjusted
-// to be no later than d. If the parent's deadline is already earlier than d,
-// WithDeadline(parent, d) is semantically equivalent to parent. The returned
-// context's Done channel is closed when the deadline expires, when the returned
-// cancel function is called, or when the parent context's Done channel is
-// closed, whichever happens first.
-//
-// Canceling this context releases resources associated with it, so code should
-// call cancel as soon as the operations running in this Context complete.
-func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) {
- if cur, ok := parent.Deadline(); ok && cur.Before(deadline) {
- // The current deadline is already sooner than the new one.
- return WithCancel(parent)
- }
- c := &timerCtx{
- cancelCtx: newCancelCtx(parent),
- deadline: deadline,
- }
- propagateCancel(parent, c)
- d := deadline.Sub(time.Now())
- if d <= 0 {
- c.cancel(true, DeadlineExceeded) // deadline has already passed
- return c, func() { c.cancel(true, Canceled) }
- }
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.err == nil {
- c.timer = time.AfterFunc(d, func() {
- c.cancel(true, DeadlineExceeded)
- })
- }
- return c, func() { c.cancel(true, Canceled) }
-}
-
-// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
-// implement Done and Err. It implements cancel by stopping its timer then
-// delegating to cancelCtx.cancel.
-type timerCtx struct {
- *cancelCtx
- timer *time.Timer // Under cancelCtx.mu.
-
- deadline time.Time
-}
-
-func (c *timerCtx) Deadline() (deadline time.Time, ok bool) {
- return c.deadline, true
-}
-
-func (c *timerCtx) String() string {
- return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now()))
-}
-
-func (c *timerCtx) cancel(removeFromParent bool, err error) {
- c.cancelCtx.cancel(false, err)
- if removeFromParent {
- // Remove this timerCtx from its parent cancelCtx's children.
- removeChild(c.cancelCtx.Context, c)
- }
- c.mu.Lock()
- if c.timer != nil {
- c.timer.Stop()
- c.timer = nil
- }
- c.mu.Unlock()
-}
-
-// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)).
-//
-// Canceling this context releases resources associated with it, so code should
-// call cancel as soon as the operations running in this Context complete:
-//
-// func slowOperationWithTimeout(ctx context.Context) (Result, error) {
-// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
-// defer cancel() // releases resources if slowOperation completes before timeout elapses
-// return slowOperation(ctx)
-// }
-func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
- return WithDeadline(parent, time.Now().Add(timeout))
-}
-
-// WithValue returns a copy of parent in which the value associated with key is
-// val.
-//
-// Use context Values only for request-scoped data that transits processes and
-// APIs, not for passing optional parameters to functions.
-func WithValue(parent Context, key interface{}, val interface{}) Context {
- return &valueCtx{parent, key, val}
-}
-
-// A valueCtx carries a key-value pair. It implements Value for that key and
-// delegates all other calls to the embedded Context.
-type valueCtx struct {
- Context
- key, val interface{}
-}
-
-func (c *valueCtx) String() string {
- return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val)
-}
-
-func (c *valueCtx) Value(key interface{}) interface{} {
- if c.key == key {
- return c.val
- }
- return c.Context.Value(key)
-}
diff --git a/context/pre_go19.go b/context/pre_go19.go
deleted file mode 100644
index ec5a638033..0000000000
--- a/context/pre_go19.go
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build !go1.9
-
-package context
-
-import "time"
-
-// A Context carries a deadline, a cancelation signal, and other values across
-// API boundaries.
-//
-// Context's methods may be called by multiple goroutines simultaneously.
-type Context interface {
- // Deadline returns the time when work done on behalf of this context
- // should be canceled. Deadline returns ok==false when no deadline is
- // set. Successive calls to Deadline return the same results.
- Deadline() (deadline time.Time, ok bool)
-
- // Done returns a channel that's closed when work done on behalf of this
- // context should be canceled. Done may return nil if this context can
- // never be canceled. Successive calls to Done return the same value.
- //
- // WithCancel arranges for Done to be closed when cancel is called;
- // WithDeadline arranges for Done to be closed when the deadline
- // expires; WithTimeout arranges for Done to be closed when the timeout
- // elapses.
- //
- // Done is provided for use in select statements:
- //
- // // Stream generates values with DoSomething and sends them to out
- // // until DoSomething returns an error or ctx.Done is closed.
- // func Stream(ctx context.Context, out chan<- Value) error {
- // for {
- // v, err := DoSomething(ctx)
- // if err != nil {
- // return err
- // }
- // select {
- // case <-ctx.Done():
- // return ctx.Err()
- // case out <- v:
- // }
- // }
- // }
- //
- // See http://blog.golang.org/pipelines for more examples of how to use
- // a Done channel for cancelation.
- Done() <-chan struct{}
-
- // Err returns a non-nil error value after Done is closed. Err returns
- // Canceled if the context was canceled or DeadlineExceeded if the
- // context's deadline passed. No other values for Err are defined.
- // After Done is closed, successive calls to Err return the same value.
- Err() error
-
- // Value returns the value associated with this context for key, or nil
- // if no value is associated with key. Successive calls to Value with
- // the same key returns the same result.
- //
- // Use context values only for request-scoped data that transits
- // processes and API boundaries, not for passing optional parameters to
- // functions.
- //
- // A key identifies a specific value in a Context. Functions that wish
- // to store values in Context typically allocate a key in a global
- // variable then use that key as the argument to context.WithValue and
- // Context.Value. A key can be any type that supports equality;
- // packages should define keys as an unexported type to avoid
- // collisions.
- //
- // Packages that define a Context key should provide type-safe accessors
- // for the values stores using that key:
- //
- // // Package user defines a User type that's stored in Contexts.
- // package user
- //
- // import "golang.org/x/net/context"
- //
- // // User is the type of value stored in the Contexts.
- // type User struct {...}
- //
- // // key is an unexported type for keys defined in this package.
- // // This prevents collisions with keys defined in other packages.
- // type key int
- //
- // // userKey is the key for user.User values in Contexts. It is
- // // unexported; clients use user.NewContext and user.FromContext
- // // instead of using this key directly.
- // var userKey key = 0
- //
- // // NewContext returns a new Context that carries value u.
- // func NewContext(ctx context.Context, u *User) context.Context {
- // return context.WithValue(ctx, userKey, u)
- // }
- //
- // // FromContext returns the User value stored in ctx, if any.
- // func FromContext(ctx context.Context) (*User, bool) {
- // u, ok := ctx.Value(userKey).(*User)
- // return u, ok
- // }
- Value(key interface{}) interface{}
-}
-
-// A CancelFunc tells an operation to abandon its work.
-// A CancelFunc does not wait for the work to stop.
-// After the first call, subsequent calls to a CancelFunc do nothing.
-type CancelFunc func()
diff --git a/context/withtimeout_test.go b/context/withtimeout_test.go
deleted file mode 100644
index e6f56691d1..0000000000
--- a/context/withtimeout_test.go
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package context_test
-
-import (
- "fmt"
- "time"
-
- "golang.org/x/net/context"
-)
-
-// This example passes a context with a timeout to tell a blocking function that
-// it should abandon its work after the timeout elapses.
-func ExampleWithTimeout() {
- // Pass a context with a timeout to tell a blocking function that it
- // should abandon its work after the timeout elapses.
- ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
- defer cancel()
-
- select {
- case <-time.After(1 * time.Second):
- fmt.Println("overslept")
- case <-ctx.Done():
- fmt.Println(ctx.Err()) // prints "context deadline exceeded"
- }
-
- // Output:
- // context deadline exceeded
-}
diff --git a/go.mod b/go.mod
index ea400e5945..37aac27a62 100644
--- a/go.mod
+++ b/go.mod
@@ -1,10 +1,10 @@
module golang.org/x/net
-go 1.18
+go 1.23.0
require (
- golang.org/x/crypto v0.25.0
- golang.org/x/sys v0.22.0
- golang.org/x/term v0.22.0
- golang.org/x/text v0.16.0
+ golang.org/x/crypto v0.35.0
+ golang.org/x/sys v0.30.0
+ golang.org/x/term v0.29.0
+ golang.org/x/text v0.22.0
)
diff --git a/go.sum b/go.sum
index c3977102eb..5f95431dfa 100644
--- a/go.sum
+++ b/go.sum
@@ -1,8 +1,8 @@
-golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
-golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
-golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
-golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
-golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
-golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
-golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
+golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
+golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
+golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
+golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
+golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
+golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
+golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
diff --git a/html/doc.go b/html/doc.go
index 3a7e5ab176..885c4c5936 100644
--- a/html/doc.go
+++ b/html/doc.go
@@ -78,16 +78,11 @@ example, to process each anchor node in depth-first order:
if err != nil {
// ...
}
- var f func(*html.Node)
- f = func(n *html.Node) {
+ for n := range doc.Descendants() {
if n.Type == html.ElementNode && n.Data == "a" {
// Do something with n...
}
- for c := n.FirstChild; c != nil; c = c.NextSibling {
- f(c)
- }
}
- f(doc)
The relevant specifications include:
https://html.spec.whatwg.org/multipage/syntax.html and
diff --git a/html/doctype.go b/html/doctype.go
index c484e5a94f..bca3ae9a0c 100644
--- a/html/doctype.go
+++ b/html/doctype.go
@@ -87,7 +87,7 @@ func parseDoctype(s string) (n *Node, quirks bool) {
}
}
if lastAttr := n.Attr[len(n.Attr)-1]; lastAttr.Key == "system" &&
- strings.ToLower(lastAttr.Val) == "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd" {
+ strings.EqualFold(lastAttr.Val, "http://www.ibm.com/data/dtd/v11/ibmxhtml1-transitional.dtd") {
quirks = true
}
}
diff --git a/html/example_test.go b/html/example_test.go
index 0b06ed7730..830f0b27af 100644
--- a/html/example_test.go
+++ b/html/example_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+//go:build go1.23
+
// This example demonstrates parsing HTML data and walking the resulting tree.
package html_test
@@ -11,6 +13,7 @@ import (
"strings"
"golang.org/x/net/html"
+ "golang.org/x/net/html/atom"
)
func ExampleParse() {
@@ -19,9 +22,8 @@ func ExampleParse() {
if err != nil {
log.Fatal(err)
}
- var f func(*html.Node)
- f = func(n *html.Node) {
- if n.Type == html.ElementNode && n.Data == "a" {
+ for n := range doc.Descendants() {
+ if n.Type == html.ElementNode && n.DataAtom == atom.A {
for _, a := range n.Attr {
if a.Key == "href" {
fmt.Println(a.Val)
@@ -29,11 +31,8 @@ func ExampleParse() {
}
}
}
- for c := n.FirstChild; c != nil; c = c.NextSibling {
- f(c)
- }
}
- f(doc)
+
// Output:
// foo
// /bar/baz
diff --git a/html/foreign.go b/html/foreign.go
index 9da9e9dc42..e8515d8e88 100644
--- a/html/foreign.go
+++ b/html/foreign.go
@@ -40,8 +40,7 @@ func htmlIntegrationPoint(n *Node) bool {
if n.Data == "annotation-xml" {
for _, a := range n.Attr {
if a.Key == "encoding" {
- val := strings.ToLower(a.Val)
- if val == "text/html" || val == "application/xhtml+xml" {
+ if strings.EqualFold(a.Val, "text/html") || strings.EqualFold(a.Val, "application/xhtml+xml") {
return true
}
}
diff --git a/html/iter.go b/html/iter.go
new file mode 100644
index 0000000000..54be8fd30f
--- /dev/null
+++ b/html/iter.go
@@ -0,0 +1,56 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.23
+
+package html
+
+import "iter"
+
+// Ancestors returns an iterator over the ancestors of n, starting with n.Parent.
+//
+// Mutating a Node or its parents while iterating may have unexpected results.
+func (n *Node) Ancestors() iter.Seq[*Node] {
+ _ = n.Parent // eager nil check
+
+ return func(yield func(*Node) bool) {
+ for p := n.Parent; p != nil && yield(p); p = p.Parent {
+ }
+ }
+}
+
+// ChildNodes returns an iterator over the immediate children of n,
+// starting with n.FirstChild.
+//
+// Mutating a Node or its children while iterating may have unexpected results.
+func (n *Node) ChildNodes() iter.Seq[*Node] {
+ _ = n.FirstChild // eager nil check
+
+ return func(yield func(*Node) bool) {
+ for c := n.FirstChild; c != nil && yield(c); c = c.NextSibling {
+ }
+ }
+
+}
+
+// Descendants returns an iterator over all nodes recursively beneath
+// n, excluding n itself. Nodes are visited in depth-first preorder.
+//
+// Mutating a Node or its descendants while iterating may have unexpected results.
+func (n *Node) Descendants() iter.Seq[*Node] {
+ _ = n.FirstChild // eager nil check
+
+ return func(yield func(*Node) bool) {
+ n.descendants(yield)
+ }
+}
+
+func (n *Node) descendants(yield func(*Node) bool) bool {
+ for c := range n.ChildNodes() {
+ if !yield(c) || !c.descendants(yield) {
+ return false
+ }
+ }
+ return true
+}
diff --git a/html/iter_test.go b/html/iter_test.go
new file mode 100644
index 0000000000..cca7f82f54
--- /dev/null
+++ b/html/iter_test.go
@@ -0,0 +1,96 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.23
+
+package html
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestNode_ChildNodes(t *testing.T) {
+ tests := []struct {
+ in string
+ want string
+ }{
+ {"", ""},
+ {" ", "a"},
+ {"a", "a"},
+ {" ", "a b"},
+ {"a c", "a b c"},
+ {"a d", "a b d"},
+ {"ce fi ", "a f g h"},
+ }
+ for _, test := range tests {
+ doc, err := Parse(strings.NewReader(test.in))
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Drill to
+ n := doc.FirstChild.FirstChild.NextSibling
+ var results []string
+ for c := range n.ChildNodes() {
+ results = append(results, c.Data)
+ }
+ if got := strings.Join(results, " "); got != test.want {
+ t.Errorf("ChildNodes = %q, want %q", got, test.want)
+ }
+ }
+}
+
+func TestNode_Descendants(t *testing.T) {
+ tests := []struct {
+ in string
+ want string
+ }{
+ {"", ""},
+ {" ", "a"},
+ {" ", "a b"},
+ {"b ", "a b"},
+ {" ", "a b"},
+ {"b d ", "a b c d"},
+ {"b e ", "a b c d e"},
+ {"df gj ", "a b c d e f g h i j"},
+ }
+ for _, test := range tests {
+ doc, err := Parse(strings.NewReader(test.in))
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Drill to
+ n := doc.FirstChild.FirstChild.NextSibling
+ var results []string
+ for c := range n.Descendants() {
+ results = append(results, c.Data)
+ }
+ if got := strings.Join(results, " "); got != test.want {
+ t.Errorf("Descendants = %q; want: %q", got, test.want)
+ }
+ }
+}
+
+func TestNode_Ancestors(t *testing.T) {
+ for _, size := range []int{0, 1, 2, 10, 100, 10_000} {
+ n := buildChain(size)
+ nParents := 0
+ for _ = range n.Ancestors() {
+ nParents++
+ }
+ if nParents != size {
+ t.Errorf("number of Ancestors = %d; want: %d", nParents, size)
+ }
+ }
+}
+
+func buildChain(size int) *Node {
+ child := new(Node)
+ for range size {
+ parent := child
+ child = new(Node)
+ parent.AppendChild(child)
+ }
+ return child
+}
diff --git a/html/node.go b/html/node.go
index 1350eef22c..77741a1950 100644
--- a/html/node.go
+++ b/html/node.go
@@ -38,6 +38,10 @@ var scopeMarker = Node{Type: scopeMarkerNode}
// that it looks like "a ",
"",
+ " ",
}
for _, src := range srcs {
// The next line shouldn't infinite-loop.
diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go
index 6404aaf157..d89c257ae7 100644
--- a/http/httpproxy/proxy.go
+++ b/http/httpproxy/proxy.go
@@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"net"
+ "net/netip"
"net/url"
"os"
"strings"
@@ -177,8 +178,10 @@ func (cfg *config) useProxy(addr string) bool {
if host == "localhost" {
return false
}
- ip := net.ParseIP(host)
- if ip != nil {
+ nip, err := netip.ParseAddr(host)
+ var ip net.IP
+ if err == nil {
+ ip = net.IP(nip.AsSlice())
if ip.IsLoopback() {
return false
}
@@ -360,6 +363,9 @@ type domainMatch struct {
}
func (m domainMatch) match(host, port string, ip net.IP) bool {
+ if ip != nil {
+ return false
+ }
if strings.HasSuffix(host, m.host) || (m.matchHost && host == m.host[1:]) {
return m.port == "" || m.port == port
}
diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go
index 790afdab77..a1dd2e83fd 100644
--- a/http/httpproxy/proxy_test.go
+++ b/http/httpproxy/proxy_test.go
@@ -211,6 +211,13 @@ var proxyForURLTests = []proxyForURLTest{{
},
req: "http://www.xn--fsq092h.com",
want: "",
+}, {
+ cfg: httpproxy.Config{
+ NoProxy: "example.com",
+ HTTPProxy: "proxy",
+ },
+ req: "http://[1000::%25.example.com]:123",
+ want: "http://proxy",
},
}
diff --git a/http2/client_conn_pool.go b/http2/client_conn_pool.go
index 780968d6c1..e81b73e6a7 100644
--- a/http2/client_conn_pool.go
+++ b/http2/client_conn_pool.go
@@ -8,8 +8,8 @@ package http2
import (
"context"
- "crypto/tls"
"errors"
+ "net"
"net/http"
"sync"
)
@@ -158,7 +158,7 @@ func (c *dialCall) dial(ctx context.Context, addr string) {
// This code decides which ones live or die.
// The return value used is whether c was used.
// c is never closed.
-func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c *tls.Conn) (used bool, err error) {
+func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
p.mu.Lock()
for _, cc := range p.conns[key] {
if cc.CanTakeNewRequest() {
@@ -194,8 +194,8 @@ type addConnCall struct {
err error
}
-func (c *addConnCall) run(t *Transport, key string, tc *tls.Conn) {
- cc, err := t.NewClientConn(tc)
+func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
+ cc, err := t.NewClientConn(nc)
p := c.p
p.mu.Lock()
diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go
index 5533790992..f9e9a2fdaa 100644
--- a/http2/clientconn_test.go
+++ b/http2/clientconn_test.go
@@ -10,6 +10,7 @@ package http2
import (
"bytes"
"context"
+ "crypto/tls"
"fmt"
"io"
"net/http"
@@ -19,6 +20,7 @@ import (
"time"
"golang.org/x/net/http2/hpack"
+ "golang.org/x/net/internal/gate"
)
// TestTestClientConn demonstrates usage of testClientConn.
@@ -112,27 +114,40 @@ func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientCo
cc: cc,
group: cc.t.transportTestHooks.group.(*synctestGroup),
}
- cli, srv := synctestNetPipe(tc.group)
+
+ // srv is the side controlled by the test.
+ var srv *synctestNetConn
+ if cc.tconn == nil {
+ // If cc.tconn is nil, we're being called with a new conn created by the
+ // Transport's client pool. This path skips dialing the server, and we
+ // create a test connection pair here.
+ cc.tconn, srv = synctestNetPipe(tc.group)
+ } else {
+ // If cc.tconn is non-nil, we're in a test which provides a conn to the
+ // Transport via a TLSNextProto hook. Extract the test connection pair.
+ if tc, ok := cc.tconn.(*tls.Conn); ok {
+ // Unwrap any *tls.Conn to the underlying net.Conn,
+ // to avoid dealing with encryption in tests.
+ cc.tconn = tc.NetConn()
+ }
+ srv = cc.tconn.(*synctestNetConn).peer
+ }
+
srv.SetReadDeadline(tc.group.Now())
srv.autoWait = true
tc.netconn = srv
tc.enc = hpack.NewEncoder(&tc.encbuf)
-
- // all writes and reads are finished.
- //
- // cli is the ClientConn's side, srv is the side controlled by the test.
- cc.tconn = cli
tc.fr = NewFramer(srv, srv)
tc.testConnFramer = testConnFramer{
t: t,
fr: tc.fr,
dec: hpack.NewDecoder(initialHeaderTableSize, nil),
}
-
tc.fr.SetMaxReadFrameSize(10 << 20)
t.Cleanup(func() {
tc.closeWrite()
})
+
return tc
}
@@ -148,7 +163,7 @@ func (tc *testClientConn) readClientPreface() {
}
}
-func newTestClientConn(t *testing.T, opts ...func(*Transport)) *testClientConn {
+func newTestClientConn(t *testing.T, opts ...any) *testClientConn {
t.Helper()
tt := newTestTransport(t, opts...)
@@ -192,7 +207,7 @@ func (tc *testClientConn) closeWrite() {
// testRequestBody is a Request.Body for use in tests.
type testRequestBody struct {
tc *testClientConn
- gate gate
+ gate gate.Gate
// At most one of buf or bytes can be set at any given time:
buf bytes.Buffer // specific bytes to read from the body
@@ -204,18 +219,18 @@ type testRequestBody struct {
func (tc *testClientConn) newRequestBody() *testRequestBody {
b := &testRequestBody{
tc: tc,
- gate: newGate(),
+ gate: gate.New(false),
}
return b
}
func (b *testRequestBody) unlock() {
- b.gate.unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil)
+ b.gate.Unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil)
}
// Read is called by the ClientConn to read from a request body.
func (b *testRequestBody) Read(p []byte) (n int, _ error) {
- if err := b.gate.waitAndLock(context.Background()); err != nil {
+ if err := b.gate.WaitAndLock(context.Background()); err != nil {
return 0, err
}
defer b.unlock()
@@ -244,7 +259,7 @@ func (b *testRequestBody) Close() error {
// writeBytes adds n arbitrary bytes to the body.
func (b *testRequestBody) writeBytes(n int) {
defer b.tc.sync()
- b.gate.lock()
+ b.gate.Lock()
defer b.unlock()
b.bytes += n
b.checkWrite()
@@ -254,7 +269,7 @@ func (b *testRequestBody) writeBytes(n int) {
// Write adds bytes to the body.
func (b *testRequestBody) Write(p []byte) (int, error) {
defer b.tc.sync()
- b.gate.lock()
+ b.gate.Lock()
defer b.unlock()
n, err := b.buf.Write(p)
b.checkWrite()
@@ -273,7 +288,7 @@ func (b *testRequestBody) checkWrite() {
// closeWithError sets an error which will be returned by Read.
func (b *testRequestBody) closeWithError(err error) {
defer b.tc.sync()
- b.gate.lock()
+ b.gate.Lock()
defer b.unlock()
b.err = err
}
@@ -486,7 +501,7 @@ type testTransport struct {
ccs []*testClientConn
}
-func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport {
+func newTestTransport(t *testing.T, opts ...any) *testTransport {
tt := &testTransport{
t: t,
group: newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)),
@@ -495,7 +510,17 @@ func newTestTransport(t *testing.T, opts ...func(*Transport)) *testTransport {
tr := &Transport{}
for _, o := range opts {
- o(tr)
+ switch o := o.(type) {
+ case func(*http.Transport):
+ if tr.t1 == nil {
+ tr.t1 = &http.Transport{}
+ }
+ o(tr.t1)
+ case func(*Transport):
+ o(tr)
+ case *Transport:
+ tr = o
+ }
}
tt.tr = tr
diff --git a/http2/config.go b/http2/config.go
new file mode 100644
index 0000000000..ca645d9a1a
--- /dev/null
+++ b/http2/config.go
@@ -0,0 +1,122 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "math"
+ "net/http"
+ "time"
+)
+
+// http2Config is a package-internal version of net/http.HTTP2Config.
+//
+// http.HTTP2Config was added in Go 1.24.
+// When running with a version of net/http that includes HTTP2Config,
+// we merge the configuration with the fields in Transport or Server
+// to produce an http2Config.
+//
+// Zero valued fields in http2Config are interpreted as in the
+// net/http.HTTPConfig documentation.
+//
+// Precedence order for reconciling configurations is:
+//
+// - Use the net/http.{Server,Transport}.HTTP2Config value, when non-zero.
+// - Otherwise use the http2.{Server.Transport} value.
+// - If the resulting value is zero or out of range, use a default.
+type http2Config struct {
+ MaxConcurrentStreams uint32
+ MaxDecoderHeaderTableSize uint32
+ MaxEncoderHeaderTableSize uint32
+ MaxReadFrameSize uint32
+ MaxUploadBufferPerConnection int32
+ MaxUploadBufferPerStream int32
+ SendPingTimeout time.Duration
+ PingTimeout time.Duration
+ WriteByteTimeout time.Duration
+ PermitProhibitedCipherSuites bool
+ CountError func(errType string)
+}
+
+// configFromServer merges configuration settings from
+// net/http.Server.HTTP2Config and http2.Server.
+func configFromServer(h1 *http.Server, h2 *Server) http2Config {
+ conf := http2Config{
+ MaxConcurrentStreams: h2.MaxConcurrentStreams,
+ MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
+ MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
+ MaxReadFrameSize: h2.MaxReadFrameSize,
+ MaxUploadBufferPerConnection: h2.MaxUploadBufferPerConnection,
+ MaxUploadBufferPerStream: h2.MaxUploadBufferPerStream,
+ SendPingTimeout: h2.ReadIdleTimeout,
+ PingTimeout: h2.PingTimeout,
+ WriteByteTimeout: h2.WriteByteTimeout,
+ PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites,
+ CountError: h2.CountError,
+ }
+ fillNetHTTPServerConfig(&conf, h1)
+ setConfigDefaults(&conf, true)
+ return conf
+}
+
+// configFromTransport merges configuration settings from h2 and h2.t1.HTTP2
+// (the net/http Transport).
+func configFromTransport(h2 *Transport) http2Config {
+ conf := http2Config{
+ MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
+ MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
+ MaxReadFrameSize: h2.MaxReadFrameSize,
+ SendPingTimeout: h2.ReadIdleTimeout,
+ PingTimeout: h2.PingTimeout,
+ WriteByteTimeout: h2.WriteByteTimeout,
+ }
+
+ // Unlike most config fields, where out-of-range values revert to the default,
+ // Transport.MaxReadFrameSize clips.
+ if conf.MaxReadFrameSize < minMaxFrameSize {
+ conf.MaxReadFrameSize = minMaxFrameSize
+ } else if conf.MaxReadFrameSize > maxFrameSize {
+ conf.MaxReadFrameSize = maxFrameSize
+ }
+
+ if h2.t1 != nil {
+ fillNetHTTPTransportConfig(&conf, h2.t1)
+ }
+ setConfigDefaults(&conf, false)
+ return conf
+}
+
+func setDefault[T ~int | ~int32 | ~uint32 | ~int64](v *T, minval, maxval, defval T) {
+ if *v < minval || *v > maxval {
+ *v = defval
+ }
+}
+
+func setConfigDefaults(conf *http2Config, server bool) {
+ setDefault(&conf.MaxConcurrentStreams, 1, math.MaxUint32, defaultMaxStreams)
+ setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
+ setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
+ if server {
+ setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20)
+ } else {
+ setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow)
+ }
+ if server {
+ setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, 1<<20)
+ } else {
+ setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow)
+ }
+ setDefault(&conf.MaxReadFrameSize, minMaxFrameSize, maxFrameSize, defaultMaxReadFrameSize)
+ setDefault(&conf.PingTimeout, 1, math.MaxInt64, 15*time.Second)
+}
+
+// adjustHTTP1MaxHeaderSize converts a limit in bytes on the size of an HTTP/1 header
+// to an HTTP/2 MAX_HEADER_LIST_SIZE value.
+func adjustHTTP1MaxHeaderSize(n int64) int64 {
+ // http2's count is in a slightly different unit and includes 32 bytes per pair.
+ // So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
+ const perFieldOverhead = 32 // per http2 spec
+ const typicalHeaders = 10 // conservative
+ return n + typicalHeaders*perFieldOverhead
+}
diff --git a/http2/config_go124.go b/http2/config_go124.go
new file mode 100644
index 0000000000..5b516c55ff
--- /dev/null
+++ b/http2/config_go124.go
@@ -0,0 +1,61 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http2
+
+import "net/http"
+
+// fillNetHTTPServerConfig sets fields in conf from srv.HTTP2.
+func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {
+ fillNetHTTPConfig(conf, srv.HTTP2)
+}
+
+// fillNetHTTPTransportConfig sets fields in conf from tr.HTTP2.
+func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {
+ fillNetHTTPConfig(conf, tr.HTTP2)
+}
+
+func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
+ if h2 == nil {
+ return
+ }
+ if h2.MaxConcurrentStreams != 0 {
+ conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
+ }
+ if h2.MaxEncoderHeaderTableSize != 0 {
+ conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
+ }
+ if h2.MaxDecoderHeaderTableSize != 0 {
+ conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
+ }
+ if h2.MaxConcurrentStreams != 0 {
+ conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
+ }
+ if h2.MaxReadFrameSize != 0 {
+ conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
+ }
+ if h2.MaxReceiveBufferPerConnection != 0 {
+ conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
+ }
+ if h2.MaxReceiveBufferPerStream != 0 {
+ conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
+ }
+ if h2.SendPingTimeout != 0 {
+ conf.SendPingTimeout = h2.SendPingTimeout
+ }
+ if h2.PingTimeout != 0 {
+ conf.PingTimeout = h2.PingTimeout
+ }
+ if h2.WriteByteTimeout != 0 {
+ conf.WriteByteTimeout = h2.WriteByteTimeout
+ }
+ if h2.PermitProhibitedCipherSuites {
+ conf.PermitProhibitedCipherSuites = true
+ }
+ if h2.CountError != nil {
+ conf.CountError = h2.CountError
+ }
+}
diff --git a/http2/config_pre_go124.go b/http2/config_pre_go124.go
new file mode 100644
index 0000000000..060fd6c64c
--- /dev/null
+++ b/http2/config_pre_go124.go
@@ -0,0 +1,16 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build !go1.24
+
+package http2
+
+import "net/http"
+
+// Pre-Go 1.24 fallback.
+// The Server.HTTP2 and Transport.HTTP2 config fields were added in Go 1.24.
+
+func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {}
+
+func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {}
diff --git a/http2/config_test.go b/http2/config_test.go
new file mode 100644
index 0000000000..b8e7a7b043
--- /dev/null
+++ b/http2/config_test.go
@@ -0,0 +1,95 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http2
+
+import (
+ "net/http"
+ "testing"
+ "time"
+)
+
+func TestConfigServerSettings(t *testing.T) {
+ config := &http.HTTP2Config{
+ MaxConcurrentStreams: 1,
+ MaxDecoderHeaderTableSize: 1<<20 + 2,
+ MaxEncoderHeaderTableSize: 1<<20 + 3,
+ MaxReadFrameSize: 1<<20 + 4,
+ MaxReceiveBufferPerConnection: 64<<10 + 5,
+ MaxReceiveBufferPerStream: 64<<10 + 6,
+ }
+ const maxHeaderBytes = 4096 + 7
+ st := newServerTester(t, nil, func(s *http.Server) {
+ s.MaxHeaderBytes = maxHeaderBytes
+ s.HTTP2 = config
+ })
+ st.writePreface()
+ st.writeSettings()
+ st.wantSettings(map[SettingID]uint32{
+ SettingMaxConcurrentStreams: uint32(config.MaxConcurrentStreams),
+ SettingHeaderTableSize: uint32(config.MaxDecoderHeaderTableSize),
+ SettingInitialWindowSize: uint32(config.MaxReceiveBufferPerStream),
+ SettingMaxFrameSize: uint32(config.MaxReadFrameSize),
+ SettingMaxHeaderListSize: maxHeaderBytes + (32 * 10),
+ })
+}
+
+func TestConfigTransportSettings(t *testing.T) {
+ config := &http.HTTP2Config{
+ MaxConcurrentStreams: 1, // ignored by Transport
+ MaxDecoderHeaderTableSize: 1<<20 + 2,
+ MaxEncoderHeaderTableSize: 1<<20 + 3,
+ MaxReadFrameSize: 1<<20 + 4,
+ MaxReceiveBufferPerConnection: 64<<10 + 5,
+ MaxReceiveBufferPerStream: 64<<10 + 6,
+ }
+ const maxHeaderBytes = 4096 + 7
+ tc := newTestClientConn(t, func(tr *http.Transport) {
+ tr.HTTP2 = config
+ tr.MaxResponseHeaderBytes = maxHeaderBytes
+ })
+ tc.wantSettings(map[SettingID]uint32{
+ SettingHeaderTableSize: uint32(config.MaxDecoderHeaderTableSize),
+ SettingInitialWindowSize: uint32(config.MaxReceiveBufferPerStream),
+ SettingMaxFrameSize: uint32(config.MaxReadFrameSize),
+ SettingMaxHeaderListSize: maxHeaderBytes + (32 * 10),
+ })
+ tc.wantWindowUpdate(0, uint32(config.MaxReceiveBufferPerConnection))
+}
+
+func TestConfigPingTimeoutServer(t *testing.T) {
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.ReadIdleTimeout = 2 * time.Second
+ s.PingTimeout = 3 * time.Second
+ })
+ st.greet()
+
+ st.advance(2 * time.Second)
+ _ = readFrame[*PingFrame](t, st)
+ st.advance(3 * time.Second)
+ st.wantClosed()
+}
+
+func TestConfigPingTimeoutTransport(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.ReadIdleTimeout = 2 * time.Second
+ tr.PingTimeout = 3 * time.Second
+ })
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+
+ tc.advance(2 * time.Second)
+ tc.wantFrameType(FramePing)
+ tc.advance(3 * time.Second)
+ err := rt.err()
+ if err == nil {
+ t.Fatalf("expected connection to close")
+ }
+}
diff --git a/http2/connframes_test.go b/http2/connframes_test.go
index 7db8b74e2e..2c4532571a 100644
--- a/http2/connframes_test.go
+++ b/http2/connframes_test.go
@@ -6,7 +6,6 @@ package http2
import (
"bytes"
- "context"
"io"
"net/http"
"os"
@@ -262,6 +261,24 @@ func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
}
}
+func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
+ fr := readFrame[*SettingsFrame](tf.t, tf)
+ if fr.Header().Flags.Has(FlagSettingsAck) {
+ tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
+ }
+ for wantID, wantVal := range want {
+ gotVal, ok := fr.Value(wantID)
+ if !ok {
+ tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
+ } else if gotVal != wantVal {
+ tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
+ }
+ }
+ if tf.t.Failed() {
+ tf.t.Fatalf("%v", fr)
+ }
+}
+
func (tf *testConnFramer) wantSettingsAck() {
tf.t.Helper()
fr := readFrame[*SettingsFrame](tf.t, tf)
@@ -295,7 +312,7 @@ func (tf *testConnFramer) wantClosed() {
if err == nil {
tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
}
- if err == context.DeadlineExceeded {
+ if err == os.ErrDeadlineExceeded {
tf.t.Fatalf("connection is not closed; want it to be")
}
}
@@ -306,7 +323,7 @@ func (tf *testConnFramer) wantIdle() {
if err == nil {
tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
}
- if err != context.DeadlineExceeded {
+ if err != os.ErrDeadlineExceeded {
tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
}
}
diff --git a/http2/frame.go b/http2/frame.go
index 105c3b279c..81faec7e75 100644
--- a/http2/frame.go
+++ b/http2/frame.go
@@ -1490,7 +1490,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
pf := mh.PseudoFields()
for i, hf := range pf {
switch hf.Name {
- case ":method", ":path", ":scheme", ":authority":
+ case ":method", ":path", ":scheme", ":authority", ":protocol":
isRequest = true
case ":status":
isResponse = true
@@ -1498,7 +1498,7 @@ func (mh *MetaHeadersFrame) checkPseudos() error {
return pseudoHeaderError(hf.Name)
}
// Check for duplicates.
- // This would be a bad algorithm, but N is 4.
+ // This would be a bad algorithm, but N is 5.
// And this doesn't allocate.
for _, hf2 := range pf[:i] {
if hf.Name == hf2.Name {
diff --git a/http2/http2.go b/http2/http2.go
index 003e649f30..6c18ea230b 100644
--- a/http2/http2.go
+++ b/http2/http2.go
@@ -19,8 +19,9 @@ import (
"bufio"
"context"
"crypto/tls"
+ "errors"
"fmt"
- "io"
+ "net"
"net/http"
"os"
"sort"
@@ -37,6 +38,15 @@ var (
logFrameWrites bool
logFrameReads bool
inTests bool
+
	// Enabling extended CONNECT causes browsers to attempt to use
+ // WebSockets-over-HTTP/2. This results in problems when the server's websocket
+ // package doesn't support extended CONNECT.
+ //
+ // Disable extended CONNECT by default for now.
+ //
+ // Issue #71128.
+ disableExtendedConnectProtocol = true
)
func init() {
@@ -49,6 +59,9 @@ func init() {
logFrameWrites = true
logFrameReads = true
}
+ if strings.Contains(e, "http2xconnect=1") {
+ disableExtendedConnectProtocol = false
+ }
}
const (
@@ -140,6 +153,10 @@ func (s Setting) Valid() error {
if s.Val < 16384 || s.Val > 1<<24-1 {
return ConnectionError(ErrCodeProtocol)
}
+ case SettingEnableConnectProtocol:
+ if s.Val != 1 && s.Val != 0 {
+ return ConnectionError(ErrCodeProtocol)
+ }
}
return nil
}
@@ -149,21 +166,23 @@ func (s Setting) Valid() error {
type SettingID uint16
const (
- SettingHeaderTableSize SettingID = 0x1
- SettingEnablePush SettingID = 0x2
- SettingMaxConcurrentStreams SettingID = 0x3
- SettingInitialWindowSize SettingID = 0x4
- SettingMaxFrameSize SettingID = 0x5
- SettingMaxHeaderListSize SettingID = 0x6
+ SettingHeaderTableSize SettingID = 0x1
+ SettingEnablePush SettingID = 0x2
+ SettingMaxConcurrentStreams SettingID = 0x3
+ SettingInitialWindowSize SettingID = 0x4
+ SettingMaxFrameSize SettingID = 0x5
+ SettingMaxHeaderListSize SettingID = 0x6
+ SettingEnableConnectProtocol SettingID = 0x8
)
var settingName = map[SettingID]string{
- SettingHeaderTableSize: "HEADER_TABLE_SIZE",
- SettingEnablePush: "ENABLE_PUSH",
- SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
- SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
- SettingMaxFrameSize: "MAX_FRAME_SIZE",
- SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
+ SettingHeaderTableSize: "HEADER_TABLE_SIZE",
+ SettingEnablePush: "ENABLE_PUSH",
+ SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
+ SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
+ SettingMaxFrameSize: "MAX_FRAME_SIZE",
+ SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
+ SettingEnableConnectProtocol: "ENABLE_CONNECT_PROTOCOL",
}
func (s SettingID) String() string {
@@ -237,13 +256,19 @@ func (cw closeWaiter) Wait() {
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type bufferedWriter struct {
- _ incomparable
- w io.Writer // immutable
- bw *bufio.Writer // non-nil when data is buffered
+ _ incomparable
+ group synctestGroupInterface // immutable
+ conn net.Conn // immutable
+ bw *bufio.Writer // non-nil when data is buffered
+ byteTimeout time.Duration // immutable, WriteByteTimeout
}
-func newBufferedWriter(w io.Writer) *bufferedWriter {
- return &bufferedWriter{w: w}
+func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
+ return &bufferedWriter{
+ group: group,
+ conn: conn,
+ byteTimeout: timeout,
+ }
}
// bufWriterPoolBufferSize is the size of bufio.Writer's
@@ -270,7 +295,7 @@ func (w *bufferedWriter) Available() int {
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer)
- bw.Reset(w.w)
+ bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw
}
return w.bw.Write(p)
@@ -288,6 +313,38 @@ func (w *bufferedWriter) Flush() error {
return err
}
+type bufferedWriterTimeoutWriter bufferedWriter
+
+func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
+ return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
+}
+
+// writeWithByteTimeout writes to conn.
+// If more than timeout passes without any bytes being written to the connection,
+// the write fails.
+func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
+ if timeout <= 0 {
+ return conn.Write(p)
+ }
+ for {
+ var now time.Time
+ if group == nil {
+ now = time.Now()
+ } else {
+ now = group.Now()
+ }
+ conn.SetWriteDeadline(now.Add(timeout))
+ nn, err := conn.Write(p[n:])
+ n += nn
+ if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
+ // Either we finished the write, made no progress, or hit the deadline.
+ // Whichever it is, we're done now.
+ conn.SetWriteDeadline(time.Time{})
+ return n, err
+ }
+ }
+}
+
func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 {
panic("out of range")
@@ -358,23 +415,6 @@ func (s *sorter) SortStrings(ss []string) {
s.v = save
}
-// validPseudoPath reports whether v is a valid :path pseudo-header
-// value. It must be either:
-//
-// - a non-empty string starting with '/'
-// - the string '*', for OPTIONS requests.
-//
-// For now this is only used a quick check for deciding when to clean
-// up Opaque URLs before sending requests from the Transport.
-// See golang.org/issue/16847
-//
-// We used to enforce that the path also didn't start with "//", but
-// Google's GFE accepts such paths and Chrome sends them, so ignore
-// that part of the spec. See golang.org/issue/19103.
-func validPseudoPath(v string) bool {
- return (len(v) > 0 && v[0] == '/') || v == "*"
-}
-
// incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
diff --git a/http2/http2_test.go b/http2/http2_test.go
index b7c946b982..c7774133a7 100644
--- a/http2/http2_test.go
+++ b/http2/http2_test.go
@@ -283,3 +283,20 @@ func TestNoUnicodeStrings(t *testing.T) {
t.Fatal(err)
}
}
+
+// setForTest sets *p = v, and restores its original value in t.Cleanup.
+func setForTest[T any](t *testing.T, p *T, v T) {
+ orig := *p
+ t.Cleanup(func() {
+ *p = orig
+ })
+ *p = v
+}
+
+// must returns v if err is nil, or panics otherwise.
+func must[T any](v T, err error) T {
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
diff --git a/http2/netconn_test.go b/http2/netconn_test.go
index 8a61fbef10..5a1759579e 100644
--- a/http2/netconn_test.go
+++ b/http2/netconn_test.go
@@ -28,8 +28,11 @@ func synctestNetPipe(group *synctestGroup) (r, w *synctestNetConn) {
s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001"))
s1 := newSynctestNetConnHalf(s1addr)
s2 := newSynctestNetConnHalf(s2addr)
- return &synctestNetConn{group: group, loc: s1, rem: s2},
- &synctestNetConn{group: group, loc: s2, rem: s1}
+ r = &synctestNetConn{group: group, loc: s1, rem: s2}
+ w = &synctestNetConn{group: group, loc: s2, rem: s1}
+ r.peer = w
+ w.peer = r
+ return r, w
}
// A synctestNetConn is one endpoint of the connection created by synctestNetPipe.
@@ -43,6 +46,9 @@ type synctestNetConn struct {
// When set, group.Wait is automatically called before reads and after writes.
autoWait bool
+
+ // peer is the other endpoint.
+ peer *synctestNetConn
}
// Read reads data from the connection.
@@ -70,7 +76,7 @@ func (c *synctestNetConn) Write(b []byte) (n int, err error) {
return c.rem.write(b)
}
-// IsClosed reports whether the peer has closed its end of the connection.
+// IsClosedByPeer reports whether the peer has closed its end of the connection.
func (c *synctestNetConn) IsClosedByPeer() bool {
if c.autoWait {
c.group.Wait()
diff --git a/http2/server.go b/http2/server.go
index 6c349f3ec6..b640deb0e0 100644
--- a/http2/server.go
+++ b/http2/server.go
@@ -29,6 +29,7 @@ import (
"bufio"
"bytes"
"context"
+ "crypto/rand"
"crypto/tls"
"errors"
"fmt"
@@ -49,13 +50,18 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
+ "golang.org/x/net/internal/httpcommon"
)
const (
- prefaceTimeout = 10 * time.Second
- firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
- handlerChunkWriteSize = 4 << 10
- defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
+ prefaceTimeout = 10 * time.Second
+ firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
+ handlerChunkWriteSize = 4 << 10
+ defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
+
+ // maxQueuedControlFrames is the maximum number of control frames like
+ // SETTINGS, PING and RST_STREAM that will be queued for writing before
+ // the connection is closed to prevent memory exhaustion attacks.
maxQueuedControlFrames = 10000
)
@@ -127,6 +133,22 @@ type Server struct {
// If zero or negative, there is no timeout.
IdleTimeout time.Duration
+ // ReadIdleTimeout is the timeout after which a health check using a ping
+ // frame will be carried out if no frame is received on the connection.
+ // If zero, no health check is performed.
+ ReadIdleTimeout time.Duration
+
+ // PingTimeout is the timeout after which the connection will be closed
+ // if a response to a ping is not received.
+ // If zero, a default of 15 seconds is used.
+ PingTimeout time.Duration
+
+ // WriteByteTimeout is the timeout after which a connection will be
+ // closed if no data can be written to it. The timeout begins when data is
+ // available to write, and is extended whenever any bytes are written.
+ // If zero or negative, there is no timeout.
+ WriteByteTimeout time.Duration
+
// MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1.
@@ -189,57 +211,6 @@ func (s *Server) afterFunc(d time.Duration, f func()) timer {
return timeTimer{time.AfterFunc(d, f)}
}
-func (s *Server) initialConnRecvWindowSize() int32 {
- if s.MaxUploadBufferPerConnection >= initialWindowSize {
- return s.MaxUploadBufferPerConnection
- }
- return 1 << 20
-}
-
-func (s *Server) initialStreamRecvWindowSize() int32 {
- if s.MaxUploadBufferPerStream > 0 {
- return s.MaxUploadBufferPerStream
- }
- return 1 << 20
-}
-
-func (s *Server) maxReadFrameSize() uint32 {
- if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize {
- return v
- }
- return defaultMaxReadFrameSize
-}
-
-func (s *Server) maxConcurrentStreams() uint32 {
- if v := s.MaxConcurrentStreams; v > 0 {
- return v
- }
- return defaultMaxStreams
-}
-
-func (s *Server) maxDecoderHeaderTableSize() uint32 {
- if v := s.MaxDecoderHeaderTableSize; v > 0 {
- return v
- }
- return initialHeaderTableSize
-}
-
-func (s *Server) maxEncoderHeaderTableSize() uint32 {
- if v := s.MaxEncoderHeaderTableSize; v > 0 {
- return v
- }
- return initialHeaderTableSize
-}
-
-// maxQueuedControlFrames is the maximum number of control frames like
-// SETTINGS, PING and RST_STREAM that will be queued for writing before
-// the connection is closed to prevent memory exhaustion attacks.
-func (s *Server) maxQueuedControlFrames() int {
- // TODO: if anybody asks, add a Server field, and remember to define the
- // behavior of negative values.
- return maxQueuedControlFrames
-}
-
type serverInternalState struct {
mu sync.Mutex
activeConns map[*serverConn]struct{}
@@ -336,7 +307,7 @@ func ConfigureServer(s *http.Server, conf *Server) error {
if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){}
}
- protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) {
+ protoHandler := func(hs *http.Server, c net.Conn, h http.Handler, sawClientPreface bool) {
if testHookOnConn != nil {
testHookOnConn()
}
@@ -353,12 +324,31 @@ func ConfigureServer(s *http.Server, conf *Server) error {
ctx = bc.BaseContext()
}
conf.ServeConn(c, &ServeConnOpts{
- Context: ctx,
- Handler: h,
- BaseConfig: hs,
+ Context: ctx,
+ Handler: h,
+ BaseConfig: hs,
+ SawClientPreface: sawClientPreface,
})
}
- s.TLSNextProto[NextProtoTLS] = protoHandler
+ s.TLSNextProto[NextProtoTLS] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
+ protoHandler(hs, c, h, false)
+ }
+ // The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns.
+ //
+ // A connection passed in this method has already had the HTTP/2 preface read from it.
+ s.TLSNextProto[nextProtoUnencryptedHTTP2] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
+ nc, err := unencryptedNetConnFromTLSConn(c)
+ if err != nil {
+ if lg := hs.ErrorLog; lg != nil {
+ lg.Print(err)
+ } else {
+ log.Print(err)
+ }
+ go c.Close()
+ return
+ }
+ protoHandler(hs, nc, h, true)
+ }
return nil
}
@@ -440,13 +430,15 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel()
+ http1srv := opts.baseConfig()
+ conf := configFromServer(http1srv, s)
sc := &serverConn{
srv: s,
- hs: opts.baseConfig(),
+ hs: http1srv,
conn: c,
baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(),
- bw: newBufferedWriter(c),
+ bw: newBufferedWriter(s.group, c, conf.WriteByteTimeout),
handler: opts.handler(),
streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult),
@@ -456,9 +448,12 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}),
clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
- advMaxStreams: s.maxConcurrentStreams(),
+ advMaxStreams: conf.MaxConcurrentStreams,
initialStreamSendWindowSize: initialWindowSize,
+ initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
maxFrameSize: initialMaxFrameSize,
+ pingTimeout: conf.PingTimeout,
+ countErrorFunc: conf.CountError,
serveG: newGoroutineLock(),
pushEnabled: true,
sawClientPreface: opts.SawClientPreface,
@@ -491,15 +486,15 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
sc.flow.add(initialWindowSize)
sc.inflow.init(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
- sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize())
+ sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
fr := NewFramer(sc.bw, c)
- if s.CountError != nil {
- fr.countError = s.CountError
+ if conf.CountError != nil {
+ fr.countError = conf.CountError
}
- fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil)
+ fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize()
- fr.SetMaxReadFrameSize(s.maxReadFrameSize())
+ fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
sc.framer = fr
if tc, ok := c.(connectionStater); ok {
@@ -532,7 +527,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
// So for now, do nothing here again.
}
- if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
+ if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated."
@@ -569,7 +564,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
opts.UpgradeRequest = nil
}
- sc.serve()
+ sc.serve(conf)
}
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
@@ -609,6 +604,7 @@ type serverConn struct {
tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string
writeSched WriteScheduler
+ countErrorFunc func(errType string)
// Everything following is owned by the serve loop; use serveG.check():
serveG goroutineLock // used to verify funcs are on serve()
@@ -628,6 +624,7 @@ type serverConn struct {
streams map[uint32]*stream
unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32
+ initialStreamRecvWindowSize int32
maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
@@ -638,9 +635,14 @@ type serverConn struct {
inGoAway bool // we've started to or sent GOAWAY
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write
+ pingSent bool
+ sentPingData [8]byte
goAwayCode ErrCode
shutdownTimer timer // nil until used
idleTimer timer // nil if unused
+ readIdleTimeout time.Duration
+ pingTimeout time.Duration
+ readIdleTimer timer // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@@ -655,11 +657,7 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
if n <= 0 {
n = http.DefaultMaxHeaderBytes
}
- // http2's count is in a slightly different unit and includes 32 bytes per pair.
- // So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
- const perFieldOverhead = 32 // per http2 spec
- const typicalHeaders = 10 // conservative
- return uint32(n + typicalHeaders*perFieldOverhead)
+ return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
}
func (sc *serverConn) curOpenStreams() uint32 {
@@ -815,8 +813,7 @@ const maxCachedCanonicalHeadersKeysSize = 2048
func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check()
- buildCommonHeaderMapsOnce()
- cv, ok := commonCanonHeader[v]
+ cv, ok := httpcommon.CachedCanonicalHeader(v)
if ok {
return cv
}
@@ -923,7 +920,7 @@ func (sc *serverConn) notePanic() {
}
}
-func (sc *serverConn) serve() {
+func (sc *serverConn) serve(conf http2Config) {
sc.serveG.check()
defer sc.notePanic()
defer sc.conn.Close()
@@ -935,20 +932,24 @@ func (sc *serverConn) serve() {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
}
+ settings := writeSettings{
+ {SettingMaxFrameSize, conf.MaxReadFrameSize},
+ {SettingMaxConcurrentStreams, sc.advMaxStreams},
+ {SettingMaxHeaderListSize, sc.maxHeaderListSize()},
+ {SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize},
+ {SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
+ }
+ if !disableExtendedConnectProtocol {
+ settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
+ }
sc.writeFrame(FrameWriteRequest{
- write: writeSettings{
- {SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
- {SettingMaxConcurrentStreams, sc.advMaxStreams},
- {SettingMaxHeaderListSize, sc.maxHeaderListSize()},
- {SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()},
- {SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())},
- },
+ write: settings,
})
sc.unackedSettings++
// Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens.
- if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 {
+ if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff))
}
@@ -968,11 +969,18 @@ func (sc *serverConn) serve() {
defer sc.idleTimer.Stop()
}
+ if conf.SendPingTimeout > 0 {
+ sc.readIdleTimeout = conf.SendPingTimeout
+ sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
+ defer sc.readIdleTimer.Stop()
+ }
+
go sc.readFrames() // closed by defer sc.conn.Close above
settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
defer settingsTimer.Stop()
+ lastFrameTime := sc.srv.now()
loopNum := 0
for {
loopNum++
@@ -986,6 +994,7 @@ func (sc *serverConn) serve() {
case res := <-sc.wroteFrameCh:
sc.wroteFrame(res)
case res := <-sc.readFrameCh:
+ lastFrameTime = sc.srv.now()
// Process any written frames before reading new frames from the client since a
// written frame could have triggered a new stream to be started.
if sc.writingFrameAsync {
@@ -1017,6 +1026,8 @@ func (sc *serverConn) serve() {
case idleTimerMsg:
sc.vlogf("connection is idle")
sc.goAway(ErrCodeNo)
+ case readIdleTimerMsg:
+ sc.handlePingTimer(lastFrameTime)
case shutdownTimerMsg:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
@@ -1039,7 +1050,7 @@ func (sc *serverConn) serve() {
// If the peer is causing us to generate a lot of control frames,
// but not reading them from us, assume they are trying to make us
// run out of memory.
- if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() {
+ if sc.queuedControlFrames > maxQueuedControlFrames {
sc.vlogf("http2: too many control frames in send queue, closing connection")
return
}
@@ -1055,12 +1066,39 @@ func (sc *serverConn) serve() {
}
}
+func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
+ if sc.pingSent {
+ sc.vlogf("timeout waiting for PING response")
+ sc.conn.Close()
+ return
+ }
+
+ pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
+ now := sc.srv.now()
+ if pingAt.After(now) {
+ // We received frames since arming the ping timer.
+ // Reset it for the next possible timeout.
+ sc.readIdleTimer.Reset(pingAt.Sub(now))
+ return
+ }
+
+ sc.pingSent = true
	// Ignore crypto/rand.Read errors: It generally can't fail, and worst case if it does
+ // is we send a PING frame containing 0s.
+ _, _ = rand.Read(sc.sentPingData[:])
+ sc.writeFrame(FrameWriteRequest{
+ write: &writePing{data: sc.sentPingData},
+ })
+ sc.readIdleTimer.Reset(sc.pingTimeout)
+}
+
type serverMessage int
// Message values sent to serveMsgCh.
var (
settingsTimerMsg = new(serverMessage)
idleTimerMsg = new(serverMessage)
+ readIdleTimerMsg = new(serverMessage)
shutdownTimerMsg = new(serverMessage)
gracefulShutdownMsg = new(serverMessage)
handlerDoneMsg = new(serverMessage)
@@ -1068,6 +1106,7 @@ var (
func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
+func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
func (sc *serverConn) sendServeMsg(msg interface{}) {
@@ -1320,6 +1359,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
sc.writingFrame = false
sc.writingFrameAsync = false
+ if res.err != nil {
+ sc.conn.Close()
+ }
+
wr := res.wr
if writeEndsStream(wr.write) {
@@ -1594,6 +1637,11 @@ func (sc *serverConn) processFrame(f Frame) error {
func (sc *serverConn) processPing(f *PingFrame) error {
sc.serveG.check()
if f.IsAck() {
+ if sc.pingSent && sc.sentPingData == f.Data {
+ // This is a response to a PING we sent.
+ sc.pingSent = false
+ sc.readIdleTimer.Reset(sc.readIdleTimeout)
+ }
// 6.7 PING: " An endpoint MUST NOT respond to PING frames
// containing this flag."
return nil
@@ -1757,6 +1805,9 @@ func (sc *serverConn) processSetting(s Setting) error {
sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31
case SettingMaxHeaderListSize:
sc.peerMaxHeaderListSize = s.Val
+ case SettingEnableConnectProtocol:
+ // Receipt of this parameter by a server does not
+ // have any impact
default:
// Unknown setting: "An endpoint that receives a SETTINGS
// frame with any unknown or unsupported identifier MUST
@@ -2160,7 +2211,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
st.cw.Init()
st.flow.conn = &sc.flow // link to conn-level counter
st.flow.add(sc.initialStreamSendWindowSize)
- st.inflow.init(sc.srv.initialStreamRecvWindowSize())
+ st.inflow.init(sc.initialStreamRecvWindowSize)
if sc.hs.WriteTimeout > 0 {
st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
}
@@ -2182,19 +2233,25 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
sc.serveG.check()
- rp := requestParam{
- method: f.PseudoValue("method"),
- scheme: f.PseudoValue("scheme"),
- authority: f.PseudoValue("authority"),
- path: f.PseudoValue("path"),
+ rp := httpcommon.ServerRequestParam{
+ Method: f.PseudoValue("method"),
+ Scheme: f.PseudoValue("scheme"),
+ Authority: f.PseudoValue("authority"),
+ Path: f.PseudoValue("path"),
+ Protocol: f.PseudoValue("protocol"),
}
- isConnect := rp.method == "CONNECT"
+ // extended connect is disabled, so we should not see :protocol
+ if disableExtendedConnectProtocol && rp.Protocol != "" {
+ return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
+ }
+
+ isConnect := rp.Method == "CONNECT"
if isConnect {
- if rp.path != "" || rp.scheme != "" || rp.authority == "" {
+ if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") {
return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
}
- } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
+ } else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") {
// See 8.1.2.6 Malformed Requests and Responses:
//
// Malformed requests or responses that are detected
@@ -2208,12 +2265,16 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
}
- rp.header = make(http.Header)
+ header := make(http.Header)
+ rp.Header = header
for _, hf := range f.RegularFields() {
- rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ }
+ if rp.Authority == "" {
+ rp.Authority = header.Get("Host")
}
- if rp.authority == "" {
- rp.authority = rp.header.Get("Host")
+ if rp.Protocol != "" {
+ header.Set(":protocol", rp.Protocol)
}
rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
@@ -2222,7 +2283,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
}
bodyOpen := !f.StreamEnded()
if bodyOpen {
- if vv, ok := rp.header["Content-Length"]; ok {
+ if vv, ok := rp.Header["Content-Length"]; ok {
if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
req.ContentLength = int64(cl)
} else {
@@ -2238,83 +2299,38 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res
return rw, req, nil
}
-type requestParam struct {
- method string
- scheme, authority, path string
- header http.Header
-}
-
-func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) {
+func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *http.Request, error) {
sc.serveG.check()
var tlsState *tls.ConnectionState // nil if not scheme https
- if rp.scheme == "https" {
+ if rp.Scheme == "https" {
tlsState = sc.tlsState
}
- needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue")
- if needsContinue {
- rp.header.Del("Expect")
- }
- // Merge Cookie headers into one "; "-delimited value.
- if cookies := rp.header["Cookie"]; len(cookies) > 1 {
- rp.header.Set("Cookie", strings.Join(cookies, "; "))
- }
-
- // Setup Trailers
- var trailer http.Header
- for _, v := range rp.header["Trailer"] {
- for _, key := range strings.Split(v, ",") {
- key = http.CanonicalHeaderKey(textproto.TrimString(key))
- switch key {
- case "Transfer-Encoding", "Trailer", "Content-Length":
- // Bogus. (copy of http1 rules)
- // Ignore.
- default:
- if trailer == nil {
- trailer = make(http.Header)
- }
- trailer[key] = nil
- }
- }
- }
- delete(rp.header, "Trailer")
-
- var url_ *url.URL
- var requestURI string
- if rp.method == "CONNECT" {
- url_ = &url.URL{Host: rp.authority}
- requestURI = rp.authority // mimic HTTP/1 server behavior
- } else {
- var err error
- url_, err = url.ParseRequestURI(rp.path)
- if err != nil {
- return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol))
- }
- requestURI = rp.path
+ res := httpcommon.NewServerRequest(rp)
+ if res.InvalidReason != "" {
+ return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol))
}
body := &requestBody{
conn: sc,
stream: st,
- needsContinue: needsContinue,
+ needsContinue: res.NeedsContinue,
}
- req := &http.Request{
- Method: rp.method,
- URL: url_,
+ req := (&http.Request{
+ Method: rp.Method,
+ URL: res.URL,
RemoteAddr: sc.remoteAddrStr,
- Header: rp.header,
- RequestURI: requestURI,
+ Header: rp.Header,
+ RequestURI: res.RequestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
TLS: tlsState,
- Host: rp.authority,
+ Host: rp.Authority,
Body: body,
- Trailer: trailer,
- }
- req = req.WithContext(st.ctx)
-
+ Trailer: res.Trailer,
+ }).WithContext(st.ctx)
rw := sc.newResponseWriter(st, req)
return rw, req, nil
}
@@ -2855,6 +2871,11 @@ func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
return nil
}
+func (w *responseWriter) EnableFullDuplex() error {
+ // We always support full duplex responses, so this is a no-op.
+ return nil
+}
+
func (w *responseWriter) Flush() {
w.FlushError()
}
@@ -3204,12 +3225,12 @@ func (sc *serverConn) startPush(msg *startPushRequest) {
// we start in "half closed (remote)" for simplicity.
// See further comments at the definition of stateHalfClosedRemote.
promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote)
- rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{
- method: msg.method,
- scheme: msg.url.Scheme,
- authority: msg.url.Host,
- path: msg.url.RequestURI(),
- header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
+ rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{
+ Method: msg.method,
+ Scheme: msg.url.Scheme,
+ Authority: msg.url.Host,
+ Path: msg.url.RequestURI(),
+ Header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE
})
if err != nil {
// Should not happen, since we've already validated msg.url.
@@ -3301,7 +3322,7 @@ func (sc *serverConn) countError(name string, err error) error {
if sc == nil || sc.srv == nil {
return err
}
- f := sc.srv.CountError
+ f := sc.countErrorFunc
if f == nil {
return err
}
diff --git a/http2/server_test.go b/http2/server_test.go
index 47c3c619c0..b27a127a5e 100644
--- a/http2/server_test.go
+++ b/http2/server_test.go
@@ -333,7 +333,9 @@ func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts ..
// sync waits for all goroutines to idle.
func (st *serverTester) sync() {
- st.group.Wait()
+ if st.group != nil {
+ st.group.Wait()
+ }
}
// advance advances synthetic time by a duration.
@@ -461,7 +463,8 @@ func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error
if f.FrameHeader.StreamID != 0 {
st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
}
- incr := uint32(st.sc.srv.initialConnRecvWindowSize() - initialWindowSize)
+ conf := configFromServer(st.sc.hs, st.sc.srv)
+ incr := uint32(conf.MaxUploadBufferPerConnection - initialWindowSize)
if f.Increment != incr {
st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
}
@@ -589,11 +592,12 @@ func (st *serverTester) bodylessReq1(headers ...string) {
}
func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) {
+ conf := configFromServer(st.sc.hs, st.sc.srv)
var initial int32
if streamID == 0 {
- initial = st.sc.srv.initialConnRecvWindowSize()
+ initial = conf.MaxUploadBufferPerConnection
} else {
- initial = st.sc.srv.initialStreamRecvWindowSize()
+ initial = conf.MaxUploadBufferPerStream
}
donec := make(chan struct{})
st.sc.sendServeMsg(func(sc *serverConn) {
@@ -1028,6 +1032,26 @@ func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
})
}
+func TestServer_Request_Reject_Authority_Userinfo(t *testing.T) {
+ // ":authority" MUST NOT include the deprecated userinfo subcomponent
+ // for "http" or "https" schemed URIs.
+ // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8
+ testRejectRequest(t, func(st *serverTester) {
+ var buf bytes.Buffer
+ enc := hpack.NewEncoder(&buf)
+ enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "userinfo@example.tld"})
+ enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
+ enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
+ enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1, // clients send odd numbers
+ BlockFragment: buf.Bytes(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+ })
+}
+
func testRejectRequest(t *testing.T, send func(*serverTester)) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
t.Error("server request made it to handler; should've been rejected")
@@ -2790,6 +2814,8 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) {
w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
return nil
}, func(st *serverTester) {
+ // Ignore errors from writing invalid trailers.
+ st.h1server.ErrorLog = log.New(io.Discard, "", 0)
getSlash(st)
st.wantHeaders(wantHeader{
streamID: 1,
@@ -2894,15 +2920,10 @@ func BenchmarkServerGets(b *testing.B) {
EndStream: true,
EndHeaders: true,
})
- st.wantHeaders(wantHeader{
- streamID: 1,
- endStream: true,
- })
- st.wantData(wantData{
- streamID: 1,
- endStream: true,
- size: 0,
- })
+ st.wantFrameType(FrameHeaders)
+ if df := readFrame[*DataFrame](b, st); !df.StreamEnded() {
+ b.Fatalf("DATA didn't have END_STREAM; got %v", df)
+ }
}
}
@@ -2937,15 +2958,10 @@ func BenchmarkServerPosts(b *testing.B) {
EndHeaders: true,
})
st.writeData(id, true, nil)
- st.wantHeaders(wantHeader{
- streamID: 1,
- endStream: false,
- })
- st.wantData(wantData{
- streamID: 1,
- endStream: true,
- size: 0,
- })
+ st.wantFrameType(FrameHeaders)
+ if df := readFrame[*DataFrame](b, st); !df.StreamEnded() {
+ b.Fatalf("DATA didn't have END_STREAM; got %v", df)
+ }
}
}
@@ -3287,14 +3303,8 @@ func BenchmarkServer_GetRequest(b *testing.B) {
EndStream: true,
EndHeaders: true,
})
- st.wantHeaders(wantHeader{
- streamID: streamID,
- endStream: false,
- })
- st.wantData(wantData{
- streamID: streamID,
- endStream: true,
- })
+ st.wantFrameType(FrameHeaders)
+ st.wantFrameType(FrameData)
}
}
@@ -3325,14 +3335,8 @@ func BenchmarkServer_PostRequest(b *testing.B) {
EndHeaders: true,
})
st.writeData(streamID, true, nil)
- st.wantHeaders(wantHeader{
- streamID: streamID,
- endStream: false,
- })
- st.wantData(wantData{
- streamID: streamID,
- endStream: true,
- })
+ st.wantFrameType(FrameHeaders)
+ st.wantFrameType(FrameData)
}
}
@@ -4345,7 +4349,7 @@ func TestCanonicalHeaderCacheGrowth(t *testing.T) {
}
}
-// TestServerWriteDoesNotRetainBufferAfterStreamClose checks for access to
+// TestServerWriteDoesNotRetainBufferAfterReturn checks for access to
// the slice passed to ResponseWriter.Write after Write returns.
//
// Terminating the request stream on the client causes Write to return.
@@ -4674,3 +4678,78 @@ func TestServerSetReadWriteDeadlineRace(t *testing.T) {
}
resp.Body.Close()
}
+
+func TestServerWriteByteTimeout(t *testing.T) {
+ const timeout = 1 * time.Second
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ w.Write(make([]byte, 100))
+ }, func(s *Server) {
+ s.WriteByteTimeout = timeout
+ })
+ st.greet()
+
+ st.cc.(*synctestNetConn).SetReadBufferSize(1) // write one byte at a time
+ st.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ BlockFragment: st.encodeHeader(),
+ EndStream: true,
+ EndHeaders: true,
+ })
+
+ // Read a few bytes, staying just under WriteByteTimeout.
+ for i := 0; i < 10; i++ {
+ st.advance(timeout - 1)
+ if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
+ t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
+ }
+ }
+
+ // Wait for WriteByteTimeout.
+ // The connection should close.
+ st.advance(1 * time.Second) // timeout after writing one byte
+ st.advance(1 * time.Second) // timeout after failing to write any more bytes
+ st.wantClosed()
+}
+
+func TestServerPingSent(t *testing.T) {
+ const readIdleTimeout = 15 * time.Second
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.ReadIdleTimeout = readIdleTimeout
+ })
+ st.greet()
+
+ st.wantIdle()
+
+ st.advance(readIdleTimeout)
+ _ = readFrame[*PingFrame](t, st)
+ st.wantIdle()
+
+ st.advance(14 * time.Second)
+ st.wantIdle()
+ st.advance(1 * time.Second)
+ st.wantClosed()
+}
+
+func TestServerPingResponded(t *testing.T) {
+ const readIdleTimeout = 15 * time.Second
+ st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
+ }, func(s *Server) {
+ s.ReadIdleTimeout = readIdleTimeout
+ })
+ st.greet()
+
+ st.wantIdle()
+
+ st.advance(readIdleTimeout)
+ pf := readFrame[*PingFrame](t, st)
+ st.wantIdle()
+
+ st.advance(14 * time.Second)
+ st.wantIdle()
+
+ st.writePing(true, pf.Data)
+
+ st.advance(2 * time.Second)
+ st.wantIdle()
+}
diff --git a/http2/sync_test.go b/http2/sync_test.go
index aeddbd6f3c..6687202d2c 100644
--- a/http2/sync_test.go
+++ b/http2/sync_test.go
@@ -24,9 +24,10 @@ type synctestGroup struct {
}
type goroutine struct {
- id int
- parent int
- state string
+ id int
+ parent int
+ state string
+ syscall bool
}
// newSynctest creates a new group with the synthetic clock set the provided time.
@@ -76,6 +77,14 @@ func (g *synctestGroup) Wait() {
return
}
runtime.Gosched()
+ if runtime.GOOS == "js" {
+ // When GOOS=js, we appear to need to call time.Sleep to make progress
+ // on some syscalls. In particular, without this sleep,
+ // writing to stdout (including via t.Log) can block forever.
+ for range 10 {
+ time.Sleep(1)
+ }
+ }
}
}
@@ -87,6 +96,9 @@ func (g *synctestGroup) idle() bool {
if !g.gids[gr.id] && !g.gids[gr.parent] {
continue
}
+ if gr.syscall {
+ return false
+ }
// From runtime/runtime2.go.
switch gr.state {
case "IO wait":
@@ -97,9 +109,6 @@ func (g *synctestGroup) idle() bool {
case "chan receive":
case "chan send":
case "sync.Cond.Wait":
- case "sync.Mutex.Lock":
- case "sync.RWMutex.RLock":
- case "sync.RWMutex.Lock":
default:
return false
}
@@ -138,6 +147,10 @@ func stacks(all bool) []goroutine {
panic(fmt.Errorf("3 unparsable goroutine stack:\n%s", gs))
}
state, rest, ok := strings.Cut(rest, "]")
+ isSyscall := false
+ if strings.Contains(rest, "\nsyscall.") {
+ isSyscall = true
+ }
var parent int
_, rest, ok = strings.Cut(rest, "\ncreated by ")
if ok && strings.Contains(rest, " in goroutine ") {
@@ -155,9 +168,10 @@ func stacks(all bool) []goroutine {
}
}
goroutines = append(goroutines, goroutine{
- id: id,
- parent: parent,
- state: state,
+ id: id,
+ parent: parent,
+ state: state,
+ syscall: isSyscall,
})
}
return goroutines
@@ -291,3 +305,25 @@ func (tm *fakeTimer) Stop() bool {
delete(tm.g.timers, tm)
return stopped
}
+
+// TestSynctestLogs verifies that t.Log works,
+// in particular that the GOOS=js workaround in synctestGroup.Wait is working.
+// (When GOOS=js, writing to stdout can hang indefinitely if some goroutine loops
+// calling runtime.Gosched; see Wait for the workaround.)
+func TestSynctestLogs(t *testing.T) {
+ g := newSynctest(time.Now())
+ donec := make(chan struct{})
+ go func() {
+ g.Join()
+ for range 100 {
+ t.Logf("logging a long line")
+ }
+ close(donec)
+ }()
+ g.Wait()
+ select {
+ case <-donec:
+ default:
+ panic("done")
+ }
+}
diff --git a/http2/transport.go b/http2/transport.go
index 61f511f97a..f26356b9cd 100644
--- a/http2/transport.go
+++ b/http2/transport.go
@@ -25,8 +25,6 @@ import (
"net/http"
"net/http/httptrace"
"net/textproto"
- "os"
- "sort"
"strconv"
"strings"
"sync"
@@ -36,6 +34,7 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
+ "golang.org/x/net/internal/httpcommon"
)
const (
@@ -203,6 +202,20 @@ func (t *Transport) markNewGoroutine() {
}
}
+func (t *Transport) now() time.Time {
+ if t != nil && t.transportTestHooks != nil {
+ return t.transportTestHooks.group.Now()
+ }
+ return time.Now()
+}
+
+func (t *Transport) timeSince(when time.Time) time.Duration {
+ if t != nil && t.transportTestHooks != nil {
+ return t.now().Sub(when)
+ }
+ return time.Since(when)
+}
+
// newTimer creates a new time.Timer, or a synthetic timer in tests.
func (t *Transport) newTimer(d time.Duration) timer {
if t.transportTestHooks != nil {
@@ -227,40 +240,26 @@ func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (co
}
func (t *Transport) maxHeaderListSize() uint32 {
- if t.MaxHeaderListSize == 0 {
+ n := int64(t.MaxHeaderListSize)
+ if t.t1 != nil && t.t1.MaxResponseHeaderBytes != 0 {
+ n = t.t1.MaxResponseHeaderBytes
+ if n > 0 {
+ n = adjustHTTP1MaxHeaderSize(n)
+ }
+ }
+ if n <= 0 {
return 10 << 20
}
- if t.MaxHeaderListSize == 0xffffffff {
+ if n >= 0xffffffff {
return 0
}
- return t.MaxHeaderListSize
-}
-
-func (t *Transport) maxFrameReadSize() uint32 {
- if t.MaxReadFrameSize == 0 {
- return 0 // use the default provided by the peer
- }
- if t.MaxReadFrameSize < minMaxFrameSize {
- return minMaxFrameSize
- }
- if t.MaxReadFrameSize > maxFrameSize {
- return maxFrameSize
- }
- return t.MaxReadFrameSize
+ return uint32(n)
}
func (t *Transport) disableCompression() bool {
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}
-func (t *Transport) pingTimeout() time.Duration {
- if t.PingTimeout == 0 {
- return 15 * time.Second
- }
- return t.PingTimeout
-
-}
-
// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
// It returns an error if t1 has already been HTTP/2-enabled.
//
@@ -296,8 +295,8 @@ func configureTransports(t1 *http.Transport) (*Transport, error) {
if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
}
- upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
- addr := authorityAddr("https", authority)
+ upgradeFn := func(scheme, authority string, c net.Conn) http.RoundTripper {
+ addr := authorityAddr(scheme, authority)
if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
go c.Close()
return erringRoundTripper{err}
@@ -308,18 +307,37 @@ func configureTransports(t1 *http.Transport) (*Transport, error) {
// was unknown)
go c.Close()
}
+ if scheme == "http" {
+ return (*unencryptedTransport)(t2)
+ }
return t2
}
- if m := t1.TLSNextProto; len(m) == 0 {
- t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
- "h2": upgradeFn,
+ if t1.TLSNextProto == nil {
+ t1.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
+ }
+ t1.TLSNextProto[NextProtoTLS] = func(authority string, c *tls.Conn) http.RoundTripper {
+ return upgradeFn("https", authority, c)
+ }
+ // The "unencrypted_http2" TLSNextProto key is used to pass off non-TLS HTTP/2 conns.
+ t1.TLSNextProto[nextProtoUnencryptedHTTP2] = func(authority string, c *tls.Conn) http.RoundTripper {
+ nc, err := unencryptedNetConnFromTLSConn(c)
+ if err != nil {
+ go c.Close()
+ return erringRoundTripper{err}
}
- } else {
- m["h2"] = upgradeFn
+ return upgradeFn("http", authority, nc)
}
return t2, nil
}
+// unencryptedTransport is a Transport with a RoundTrip method that
+// always permits http:// URLs.
+type unencryptedTransport Transport
+
+func (t *unencryptedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
+ return (*Transport)(t).RoundTripOpt(req, RoundTripOpt{allowHTTP: true})
+}
+
func (t *Transport) connPool() ClientConnPool {
t.connPoolOnce.Do(t.initConnPool)
return t.connPoolOrDef
@@ -339,7 +357,7 @@ type ClientConn struct {
t *Transport
tconn net.Conn // usually *tls.Conn, except specialized impls
tlsState *tls.ConnectionState // nil only for specialized impls
- reused uint32 // whether conn is being reused; atomic
+ atomicReused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request
getConnCalled bool // used by clientConnPool
@@ -350,31 +368,55 @@ type ClientConn struct {
idleTimeout time.Duration // or 0 for never
idleTimer timer
- mu sync.Mutex // guards following
- cond *sync.Cond // hold mu; broadcast on flow/closed changes
- flow outflow // our conn-level flow control quota (cs.outflow is per stream)
- inflow inflow // peer's conn-level flow control
- doNotReuse bool // whether conn is marked to not be reused for any future requests
- closing bool
- closed bool
- seenSettings bool // true if we've seen a settings frame, false otherwise
- wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
- goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
- goAwayDebug string // goAway frame's debug data, retained as a string
- streams map[uint32]*clientStream // client-initiated
- streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
- nextStreamID uint32
- pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
- pings map[[8]byte]chan struct{} // in flight ping data to notification channel
- br *bufio.Reader
- lastActive time.Time
- lastIdle time.Time // time last idle
+ mu sync.Mutex // guards following
+ cond *sync.Cond // hold mu; broadcast on flow/closed changes
+ flow outflow // our conn-level flow control quota (cs.outflow is per stream)
+ inflow inflow // peer's conn-level flow control
+ doNotReuse bool // whether conn is marked to not be reused for any future requests
+ closing bool
+ closed bool
+ closedOnIdle bool // true if conn was closed for idleness
+ seenSettings bool // true if we've seen a settings frame, false otherwise
+ seenSettingsChan chan struct{} // closed when seenSettings is true or frame reading fails
+ wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
+ goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
+ goAwayDebug string // goAway frame's debug data, retained as a string
+ streams map[uint32]*clientStream // client-initiated
+ streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip
+ nextStreamID uint32
+ pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams
+ pings map[[8]byte]chan struct{} // in flight ping data to notification channel
+ br *bufio.Reader
+ lastActive time.Time
+ lastIdle time.Time // time last idle
// Settings from peer: (also guarded by wmu)
- maxFrameSize uint32
- maxConcurrentStreams uint32
- peerMaxHeaderListSize uint64
- peerMaxHeaderTableSize uint32
- initialWindowSize uint32
+ maxFrameSize uint32
+ maxConcurrentStreams uint32
+ peerMaxHeaderListSize uint64
+ peerMaxHeaderTableSize uint32
+ initialWindowSize uint32
+ initialStreamRecvWindowSize int32
+ readIdleTimeout time.Duration
+ pingTimeout time.Duration
+ extendedConnectAllowed bool
+
+ // rstStreamPingsBlocked works around an unfortunate gRPC behavior.
+ // gRPC strictly limits the number of PING frames that it will receive.
+ // The default is two pings per two hours, but the limit resets every time
+ // the gRPC endpoint sends a HEADERS or DATA frame. See golang/go#70575.
+ //
+ // rstStreamPingsBlocked is set after receiving a response to a PING frame
+ // bundled with an RST_STREAM (see pendingResets below), and cleared after
+ // receiving a HEADERS or DATA frame.
+ rstStreamPingsBlocked bool
+
+ // pendingResets is the number of RST_STREAM frames we have sent to the peer,
+ // without confirming that the peer has received them. When we send a RST_STREAM,
+ // we bundle it with a PING frame, unless a PING is already in flight. We count
+ // the reset stream against the connection's concurrency limit until we get
+ // a PING response. This limits the number of requests we'll try to send to a
+ // completely unresponsive connection.
+ pendingResets int
// reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
// Write to reqHeaderMu to lock it, read from it to unlock.
@@ -432,12 +474,12 @@ type clientStream struct {
sentHeaders bool
// owned by clientConnReadLoop:
- firstByte bool // got the first response byte
- pastHeaders bool // got first MetaHeadersFrame (actual headers)
- pastTrailers bool // got optional second MetaHeadersFrame (trailers)
- num1xx uint8 // number of 1xx responses seen
- readClosed bool // peer sent an END_STREAM flag
- readAborted bool // read loop reset the stream
+ firstByte bool // got the first response byte
+ pastHeaders bool // got first MetaHeadersFrame (actual headers)
+ pastTrailers bool // got optional second MetaHeadersFrame (trailers)
+ readClosed bool // peer sent an END_STREAM flag
+ readAborted bool // read loop reset the stream
+ totalHeaderSize int64 // total size of 1xx headers seen
trailer http.Header // accumulated trailers
resTrailer *http.Header // client's Response.Trailer
@@ -499,6 +541,7 @@ func (cs *clientStream) closeReqBodyLocked() {
}
type stickyErrWriter struct {
+ group synctestGroupInterface
conn net.Conn
timeout time.Duration
err *error
@@ -508,22 +551,9 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
- for {
- if sew.timeout != 0 {
- sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
- }
- nn, err := sew.conn.Write(p[n:])
- n += nn
- if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
- // Keep extending the deadline so long as we're making progress.
- continue
- }
- if sew.timeout != 0 {
- sew.conn.SetWriteDeadline(time.Time{})
- }
- *sew.err = err
- return n, err
- }
+ n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
+ *sew.err = err
+ return n, err
}
// noCachedConnError is the concrete type of ErrNoCachedConn, which
@@ -554,6 +584,8 @@ type RoundTripOpt struct {
// no cached connection is available, RoundTripOpt
// will return ErrNoCachedConn.
OnlyCachedConn bool
+
+ allowHTTP bool // allow http:// URLs
}
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -586,7 +618,14 @@ func authorityAddr(scheme string, authority string) (addr string) {
// RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
- if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
+ switch req.URL.Scheme {
+ case "https":
+ // Always okay.
+ case "http":
+ if !t.AllowHTTP && !opt.allowHTTP {
+ return nil, errors.New("http2: unencrypted HTTP/2 not enabled")
+ }
+ default:
return nil, errors.New("http2: unsupported scheme")
}
@@ -597,7 +636,7 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
return nil, err
}
- reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
+ reused := !atomic.CompareAndSwapUint32(&cc.atomicReused, 0, 1)
traceGotConn(req, cc, reused)
res, err := cc.RoundTrip(req)
if err != nil && retry <= 6 {
@@ -622,6 +661,22 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
}
}
}
+ if err == errClientConnNotEstablished {
+ // This ClientConn was created recently,
+ // this is the first request to use it,
+ // and the connection is closed and not usable.
+ //
+ // In this state, cc.idleTimer will remove the conn from the pool
+ // when it fires. Stop the timer and remove it here so future requests
+ // won't try to use this connection.
+ //
+ // If the timer has already fired and we're racing it, the redundant
+ // call to MarkDead is harmless.
+ if cc.idleTimer != nil {
+ cc.idleTimer.Stop()
+ }
+ t.connPool().MarkDead(cc)
+ }
if err != nil {
t.vlogf("RoundTrip failure: %v", err)
return nil, err
@@ -640,9 +695,10 @@ func (t *Transport) CloseIdleConnections() {
}
var (
- errClientConnClosed = errors.New("http2: client conn is closed")
- errClientConnUnusable = errors.New("http2: client conn not usable")
- errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
+ errClientConnClosed = errors.New("http2: client conn is closed")
+ errClientConnUnusable = errors.New("http2: client conn not usable")
+ errClientConnNotEstablished = errors.New("http2: client conn could not be established")
+ errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
)
// shouldRetryRequest is called by RoundTrip when a request fails to get
@@ -758,44 +814,38 @@ func (t *Transport) expectContinueTimeout() time.Duration {
return t.t1.ExpectContinueTimeout
}
-func (t *Transport) maxDecoderHeaderTableSize() uint32 {
- if v := t.MaxDecoderHeaderTableSize; v > 0 {
- return v
- }
- return initialHeaderTableSize
-}
-
-func (t *Transport) maxEncoderHeaderTableSize() uint32 {
- if v := t.MaxEncoderHeaderTableSize; v > 0 {
- return v
- }
- return initialHeaderTableSize
-}
-
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives())
}
func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
+ conf := configFromTransport(t)
cc := &ClientConn{
- t: t,
- tconn: c,
- readerDone: make(chan struct{}),
- nextStreamID: 1,
- maxFrameSize: 16 << 10, // spec default
- initialWindowSize: 65535, // spec default
- maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
- peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
- streams: make(map[uint32]*clientStream),
- singleUse: singleUse,
- wantSettingsAck: true,
- pings: make(map[[8]byte]chan struct{}),
- reqHeaderMu: make(chan struct{}, 1),
- }
+ t: t,
+ tconn: c,
+ readerDone: make(chan struct{}),
+ nextStreamID: 1,
+ maxFrameSize: 16 << 10, // spec default
+ initialWindowSize: 65535, // spec default
+ initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
+ maxConcurrentStreams: initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
+ peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead.
+ streams: make(map[uint32]*clientStream),
+ singleUse: singleUse,
+ seenSettingsChan: make(chan struct{}),
+ wantSettingsAck: true,
+ readIdleTimeout: conf.SendPingTimeout,
+ pingTimeout: conf.PingTimeout,
+ pings: make(map[[8]byte]chan struct{}),
+ reqHeaderMu: make(chan struct{}, 1),
+ lastActive: t.now(),
+ }
+ var group synctestGroupInterface
if t.transportTestHooks != nil {
t.markNewGoroutine()
t.transportTestHooks.newclientconn(cc)
c = cc.tconn
+ group = t.group
}
if VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -807,24 +857,23 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
// TODO: adjust this writer size to account for frame size +
// MTU + crypto/tls record padding.
cc.bw = bufio.NewWriter(stickyErrWriter{
+ group: group,
conn: c,
- timeout: t.WriteByteTimeout,
+ timeout: conf.WriteByteTimeout,
err: &cc.werr,
})
cc.br = bufio.NewReader(c)
cc.fr = NewFramer(cc.bw, cc.br)
- if t.maxFrameReadSize() != 0 {
- cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
- }
+ cc.fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
if t.CountError != nil {
cc.fr.countError = t.CountError
}
- maxHeaderTableSize := t.maxDecoderHeaderTableSize()
+ maxHeaderTableSize := conf.MaxDecoderHeaderTableSize
cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil)
cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
cc.henc = hpack.NewEncoder(&cc.hbuf)
- cc.henc.SetMaxDynamicTableSizeLimit(t.maxEncoderHeaderTableSize())
+ cc.henc.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
cc.peerMaxHeaderTableSize = initialHeaderTableSize
if cs, ok := c.(connectionStater); ok {
@@ -834,11 +883,9 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
initialSettings := []Setting{
{ID: SettingEnablePush, Val: 0},
- {ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
- }
- if max := t.maxFrameReadSize(); max != 0 {
- initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: max})
+ {ID: SettingInitialWindowSize, Val: uint32(cc.initialStreamRecvWindowSize)},
}
+ initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: conf.MaxReadFrameSize})
if max := t.maxHeaderListSize(); max != 0 {
initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max})
}
@@ -848,8 +895,8 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
cc.bw.Write(clientPreface)
cc.fr.WriteSettings(initialSettings...)
- cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow)
- cc.inflow.init(transportDefaultConnFlow + initialWindowSize)
+ cc.fr.WriteWindowUpdate(0, uint32(conf.MaxUploadBufferPerConnection))
+ cc.inflow.init(conf.MaxUploadBufferPerConnection + initialWindowSize)
cc.bw.Flush()
if cc.werr != nil {
cc.Close()
@@ -867,7 +914,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
}
func (cc *ClientConn) healthCheck() {
- pingTimeout := cc.t.pingTimeout()
+ pingTimeout := cc.pingTimeout
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout)
@@ -995,7 +1042,7 @@ func (cc *ClientConn) State() ClientConnState {
return ClientConnState{
Closed: cc.closed,
Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil,
- StreamsActive: len(cc.streams),
+ StreamsActive: len(cc.streams) + cc.pendingResets,
StreamsReserved: cc.streamsReserved,
StreamsPending: cc.pendingRequests,
LastIdle: cc.lastIdle,
@@ -1027,16 +1074,40 @@ func (cc *ClientConn) idleStateLocked() (st clientConnIdleState) {
// writing it.
maxConcurrentOkay = true
} else {
- maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams)
+ // We can take a new request if the total of
+ // - active streams;
+ // - reservation slots for new streams; and
+ // - streams for which we have sent a RST_STREAM and a PING,
+ // but received no subsequent frame
+ // is less than the concurrency limit.
+ maxConcurrentOkay = cc.currentRequestCountLocked() < int(cc.maxConcurrentStreams)
}
st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay &&
!cc.doNotReuse &&
int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 &&
!cc.tooIdleLocked()
+
+ // If this connection has never been used for a request and is closed,
+ // then let it take a request (which will fail).
+ // If the conn was closed for idleness, we're racing the idle timer;
+ // don't try to use the conn. (Issue #70515.)
+ //
+ // This avoids a situation where an error early in a connection's lifetime
+ // goes unreported.
+ if cc.nextStreamID == 1 && cc.streamsReserved == 0 && cc.closed && !cc.closedOnIdle {
+ st.canTakeNewRequest = true
+ }
+
return
}
+// currentRequestCountLocked reports the number of concurrency slots currently in use,
+// including active streams, reserved slots, and reset streams waiting for acknowledgement.
+func (cc *ClientConn) currentRequestCountLocked() int {
+ return len(cc.streams) + cc.streamsReserved + cc.pendingResets
+}
+
func (cc *ClientConn) canTakeNewRequestLocked() bool {
st := cc.idleStateLocked()
return st.canTakeNewRequest
@@ -1049,7 +1120,7 @@ func (cc *ClientConn) tooIdleLocked() bool {
// times are compared based on their wall time. We don't want
// to reuse a connection that's been sitting idle during
// VM/laptop suspend if monotonic time was also frozen.
- return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && time.Since(cc.lastIdle.Round(0)) > cc.idleTimeout
+ return cc.idleTimeout != 0 && !cc.lastIdle.IsZero() && cc.t.timeSince(cc.lastIdle.Round(0)) > cc.idleTimeout
}
// onIdleTimeout is called from a time.AfterFunc goroutine. It will
@@ -1087,6 +1158,7 @@ func (cc *ClientConn) closeIfIdle() {
return
}
cc.closed = true
+ cc.closedOnIdle = true
nextID := cc.nextStreamID
// TODO: do clients send GOAWAY too? maybe? Just Close:
cc.mu.Unlock()
@@ -1203,23 +1275,6 @@ func (cc *ClientConn) closeForLostPing() {
// exported. At least they'll be DeepEqual for h1-vs-h2 comparisons tests.
var errRequestCanceled = errors.New("net/http: request canceled")
-func commaSeparatedTrailers(req *http.Request) (string, error) {
- keys := make([]string, 0, len(req.Trailer))
- for k := range req.Trailer {
- k = canonicalHeader(k)
- switch k {
- case "Transfer-Encoding", "Trailer", "Content-Length":
- return "", fmt.Errorf("invalid Trailer key %q", k)
- }
- keys = append(keys, k)
- }
- if len(keys) > 0 {
- sort.Strings(keys)
- return strings.Join(keys, ","), nil
- }
- return "", nil
-}
-
func (cc *ClientConn) responseHeaderTimeout() time.Duration {
if cc.t.t1 != nil {
return cc.t.t1.ResponseHeaderTimeout
@@ -1231,22 +1286,6 @@ func (cc *ClientConn) responseHeaderTimeout() time.Duration {
return 0
}
-// checkConnHeaders checks whether req has any invalid connection-level headers.
-// per RFC 7540 section 8.1.2.2: Connection-Specific Header Fields.
-// Certain headers are special-cased as okay but not transmitted later.
-func checkConnHeaders(req *http.Request) error {
- if v := req.Header.Get("Upgrade"); v != "" {
- return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"])
- }
- if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
- return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
- }
- if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
- return fmt.Errorf("http2: invalid Connection request header: %q", vv)
- }
- return nil
-}
-
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
@@ -1292,25 +1331,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream))
donec: make(chan struct{}),
}
- // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
- if !cc.t.disableCompression() &&
- req.Header.Get("Accept-Encoding") == "" &&
- req.Header.Get("Range") == "" &&
- !cs.isHead {
- // Request gzip only, not deflate. Deflate is ambiguous and
- // not as universally supported anyway.
- // See: https://zlib.net/zlib_faq.html#faq39
- //
- // Note that we don't request this for HEAD requests,
- // due to a bug in nginx:
- // http://trac.nginx.org/nginx/ticket/358
- // https://golang.org/issue/5522
- //
- // We don't request gzip if the request is for a range, since
- // auto-decoding a portion of a gzipped document will just fail
- // anyway. See https://golang.org/issue/8923
- cs.requestedGzip = true
- }
+ cs.requestedGzip = httpcommon.IsRequestGzip(req.Method, req.Header, cc.t.disableCompression())
go cs.doRequest(req, streamf)
@@ -1411,6 +1432,8 @@ func (cs *clientStream) doRequest(req *http.Request, streamf func(*clientStream)
cs.cleanupWriteRequest(err)
}
+var errExtendedConnectNotSupported = errors.New("net/http: extended connect not supported by peer")
+
// writeRequest sends a request.
//
// It returns nil after the request is written, the response read,
@@ -1422,8 +1445,11 @@ func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStre
cc := cs.cc
ctx := cs.ctx
- if err := checkConnHeaders(req); err != nil {
- return err
+ // Wait for the peer's settings frame to be received: a server can change
+ // this value later, but we only wait for the first settings frame.
+ var isExtendedConnect bool
+ if req.Method == "CONNECT" && req.Header.Get(":protocol") != "" {
+ isExtendedConnect = true
}
// Acquire the new-request lock by writing to reqHeaderMu.
@@ -1432,6 +1458,18 @@ func (cs *clientStream) writeRequest(req *http.Request, streamf func(*clientStre
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
+ if isExtendedConnect {
+ select {
+ case <-cs.reqCancel:
+ return errRequestCanceled
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cc.seenSettingsChan:
+ if !cc.extendedConnectAllowed {
+ return errExtendedConnectNotSupported
+ }
+ }
+ }
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
@@ -1570,26 +1608,39 @@ func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error {
// we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is
// sent by writeRequestBody below, along with any Trailers,
// again in form HEADERS{1}, CONTINUATION{0,})
- trailers, err := commaSeparatedTrailers(req)
- if err != nil {
- return err
- }
- hasTrailers := trailers != ""
- contentLen := actualContentLength(req)
- hasBody := contentLen != 0
- hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
+ cc.hbuf.Reset()
+ res, err := encodeRequestHeaders(req, cs.requestedGzip, cc.peerMaxHeaderListSize, func(name, value string) {
+ cc.writeHeader(name, value)
+ })
if err != nil {
- return err
+ return fmt.Errorf("http2: %w", err)
}
+ hdrs := cc.hbuf.Bytes()
// Write the request.
- endStream := !hasBody && !hasTrailers
+ endStream := !res.HasBody && !res.HasTrailers
cs.sentHeaders = true
err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs)
traceWroteHeaders(cs.trace)
return err
}
+func encodeRequestHeaders(req *http.Request, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) {
+ return httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{
+ Request: httpcommon.Request{
+ Header: req.Header,
+ Trailer: req.Trailer,
+ URL: req.URL,
+ Host: req.Host,
+ Method: req.Method,
+ ActualContentLength: actualContentLength(req),
+ },
+ AddGzipHeader: addGzipHeader,
+ PeerMaxHeaderListSize: peerMaxHeaderListSize,
+ DefaultUserAgent: defaultUserAgent,
+ }, headerf)
+}
+
// cleanupWriteRequest performs post-request tasks.
//
// If err (the result of writeRequest) is non-nil and the stream is not closed,
@@ -1613,6 +1664,7 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
cs.reqBodyClosed = make(chan struct{})
}
bodyClosed := cs.reqBodyClosed
+ closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
cc.mu.Unlock()
if mustCloseBody {
cs.reqBody.Close()
@@ -1637,16 +1689,44 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
if cs.sentHeaders {
if se, ok := err.(StreamError); ok {
if se.Cause != errFromPeer {
- cc.writeStreamReset(cs.ID, se.Code, err)
+ cc.writeStreamReset(cs.ID, se.Code, false, err)
}
} else {
- cc.writeStreamReset(cs.ID, ErrCodeCancel, err)
+ // We're cancelling an in-flight request.
+ //
+ // This could be due to the server becoming unresponsive.
+ // To avoid sending too many requests on a dead connection,
+ // we let the request continue to consume a concurrency slot
+ // until we can confirm the server is still responding.
+ // We do this by sending a PING frame along with the RST_STREAM
+ // (unless a ping is already in flight).
+ //
+ // For simplicity, we don't bother tracking the PING payload:
+ // We reset cc.pendingResets any time we receive a PING ACK.
+ //
+ // We skip this if the conn is going to be closed on idle,
+ // because it's short lived and will probably be closed before
+ // we get the ping response.
+ ping := false
+ if !closeOnIdle {
+ cc.mu.Lock()
+ // rstStreamPingsBlocked works around a gRPC behavior:
+ // see comment on the field for details.
+ if !cc.rstStreamPingsBlocked {
+ if cc.pendingResets == 0 {
+ ping = true
+ }
+ cc.pendingResets++
+ }
+ cc.mu.Unlock()
+ }
+ cc.writeStreamReset(cs.ID, ErrCodeCancel, ping, err)
}
}
cs.bufPipe.CloseWithError(err) // no-op if already closed
} else {
if cs.sentHeaders && !cs.sentEndStream {
- cc.writeStreamReset(cs.ID, ErrCodeNo, nil)
+ cc.writeStreamReset(cs.ID, ErrCodeNo, false, nil)
}
cs.bufPipe.CloseWithError(errRequestCanceled)
}
@@ -1668,12 +1748,17 @@ func (cs *clientStream) cleanupWriteRequest(err error) {
// Must hold cc.mu.
func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error {
for {
- cc.lastActive = time.Now()
+ if cc.closed && cc.nextStreamID == 1 && cc.streamsReserved == 0 {
+ // This is the very first request sent to this connection.
+ // Return a fatal error which aborts the retry loop.
+ return errClientConnNotEstablished
+ }
+ cc.lastActive = cc.t.now()
if cc.closed || !cc.canTakeNewRequestLocked() {
return errClientConnUnusable
}
cc.lastIdle = time.Time{}
- if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) {
+ if cc.currentRequestCountLocked() < int(cc.maxConcurrentStreams) {
return nil
}
cc.pendingRequests++
@@ -1943,214 +2028,6 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error)
}
}
-func validateHeaders(hdrs http.Header) string {
- for k, vv := range hdrs {
- if !httpguts.ValidHeaderFieldName(k) {
- return fmt.Sprintf("name %q", k)
- }
- for _, v := range vv {
- if !httpguts.ValidHeaderFieldValue(v) {
- // Don't include the value in the error,
- // because it may be sensitive.
- return fmt.Sprintf("value for header %q", k)
- }
- }
- }
- return ""
-}
-
-var errNilRequestURL = errors.New("http2: Request.URI is nil")
-
-// requires cc.wmu be held.
-func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
- cc.hbuf.Reset()
- if req.URL == nil {
- return nil, errNilRequestURL
- }
-
- host := req.Host
- if host == "" {
- host = req.URL.Host
- }
- host, err := httpguts.PunycodeHostPort(host)
- if err != nil {
- return nil, err
- }
- if !httpguts.ValidHostHeader(host) {
- return nil, errors.New("http2: invalid Host header")
- }
-
- var path string
- if req.Method != "CONNECT" {
- path = req.URL.RequestURI()
- if !validPseudoPath(path) {
- orig := path
- path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
- if !validPseudoPath(path) {
- if req.URL.Opaque != "" {
- return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
- } else {
- return nil, fmt.Errorf("invalid request :path %q", orig)
- }
- }
- }
- }
-
- // Check for any invalid headers+trailers and return an error before we
- // potentially pollute our hpack state. (We want to be able to
- // continue to reuse the hpack encoder for future requests)
- if err := validateHeaders(req.Header); err != "" {
- return nil, fmt.Errorf("invalid HTTP header %s", err)
- }
- if err := validateHeaders(req.Trailer); err != "" {
- return nil, fmt.Errorf("invalid HTTP trailer %s", err)
- }
-
- enumerateHeaders := func(f func(name, value string)) {
- // 8.1.2.3 Request Pseudo-Header Fields
- // The :path pseudo-header field includes the path and query parts of the
- // target URI (the path-absolute production and optionally a '?' character
- // followed by the query production, see Sections 3.3 and 3.4 of
- // [RFC3986]).
- f(":authority", host)
- m := req.Method
- if m == "" {
- m = http.MethodGet
- }
- f(":method", m)
- if req.Method != "CONNECT" {
- f(":path", path)
- f(":scheme", req.URL.Scheme)
- }
- if trailers != "" {
- f("trailer", trailers)
- }
-
- var didUA bool
- for k, vv := range req.Header {
- if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
- // Host is :authority, already sent.
- // Content-Length is automatic, set below.
- continue
- } else if asciiEqualFold(k, "connection") ||
- asciiEqualFold(k, "proxy-connection") ||
- asciiEqualFold(k, "transfer-encoding") ||
- asciiEqualFold(k, "upgrade") ||
- asciiEqualFold(k, "keep-alive") {
- // Per 8.1.2.2 Connection-Specific Header
- // Fields, don't send connection-specific
- // fields. We have already checked if any
- // are error-worthy so just ignore the rest.
- continue
- } else if asciiEqualFold(k, "user-agent") {
- // Match Go's http1 behavior: at most one
- // User-Agent. If set to nil or empty string,
- // then omit it. Otherwise if not mentioned,
- // include the default (below).
- didUA = true
- if len(vv) < 1 {
- continue
- }
- vv = vv[:1]
- if vv[0] == "" {
- continue
- }
- } else if asciiEqualFold(k, "cookie") {
- // Per 8.1.2.5 To allow for better compression efficiency, the
- // Cookie header field MAY be split into separate header fields,
- // each with one or more cookie-pairs.
- for _, v := range vv {
- for {
- p := strings.IndexByte(v, ';')
- if p < 0 {
- break
- }
- f("cookie", v[:p])
- p++
- // strip space after semicolon if any.
- for p+1 <= len(v) && v[p] == ' ' {
- p++
- }
- v = v[p:]
- }
- if len(v) > 0 {
- f("cookie", v)
- }
- }
- continue
- }
-
- for _, v := range vv {
- f(k, v)
- }
- }
- if shouldSendReqContentLength(req.Method, contentLength) {
- f("content-length", strconv.FormatInt(contentLength, 10))
- }
- if addGzipHeader {
- f("accept-encoding", "gzip")
- }
- if !didUA {
- f("user-agent", defaultUserAgent)
- }
- }
-
- // Do a first pass over the headers counting bytes to ensure
- // we don't exceed cc.peerMaxHeaderListSize. This is done as a
- // separate pass before encoding the headers to prevent
- // modifying the hpack state.
- hlSize := uint64(0)
- enumerateHeaders(func(name, value string) {
- hf := hpack.HeaderField{Name: name, Value: value}
- hlSize += uint64(hf.Size())
- })
-
- if hlSize > cc.peerMaxHeaderListSize {
- return nil, errRequestHeaderListSize
- }
-
- trace := httptrace.ContextClientTrace(req.Context())
- traceHeaders := traceHasWroteHeaderField(trace)
-
- // Header list size is ok. Write the headers.
- enumerateHeaders(func(name, value string) {
- name, ascii := lowerHeader(name)
- if !ascii {
- // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
- // field names have to be ASCII characters (just as in HTTP/1.x).
- return
- }
- cc.writeHeader(name, value)
- if traceHeaders {
- traceWroteHeaderField(trace, name, value)
- }
- })
-
- return cc.hbuf.Bytes(), nil
-}
-
-// shouldSendReqContentLength reports whether the http2.Transport should send
-// a "content-length" request header. This logic is basically a copy of the net/http
-// transferWriter.shouldSendContentLength.
-// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
-// -1 means unknown.
-func shouldSendReqContentLength(method string, contentLength int64) bool {
- if contentLength > 0 {
- return true
- }
- if contentLength < 0 {
- return false
- }
- // For zero bodies, whether we send a content-length depends on the method.
- // It also kinda doesn't matter for http2 either way, with END_STREAM.
- switch method {
- case "POST", "PUT", "PATCH":
- return true
- default:
- return false
- }
-}
-
// requires cc.wmu be held.
func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
cc.hbuf.Reset()
@@ -2167,7 +2044,7 @@ func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
}
for k, vv := range trailer {
- lowKey, ascii := lowerHeader(k)
+ lowKey, ascii := httpcommon.LowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
@@ -2199,7 +2076,7 @@ type resAndError struct {
func (cc *ClientConn) addStreamLocked(cs *clientStream) {
cs.flow.add(int32(cc.initialWindowSize))
cs.flow.setConnFlow(&cc.flow)
- cs.inflow.init(transportDefaultStreamFlow)
+ cs.inflow.init(cc.initialStreamRecvWindowSize)
cs.ID = cc.nextStreamID
cc.nextStreamID += 2
cc.streams[cs.ID] = cs
@@ -2215,10 +2092,10 @@ func (cc *ClientConn) forgetStreamID(id uint32) {
if len(cc.streams) != slen-1 {
panic("forgetting unknown stream id")
}
- cc.lastActive = time.Now()
+ cc.lastActive = cc.t.now()
if len(cc.streams) == 0 && cc.idleTimer != nil {
cc.idleTimer.Reset(cc.idleTimeout)
- cc.lastIdle = time.Now()
+ cc.lastIdle = cc.t.now()
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
@@ -2278,7 +2155,6 @@ func isEOFOrNetReadError(err error) bool {
func (rl *clientConnReadLoop) cleanup() {
cc := rl.cc
- cc.t.connPool().MarkDead(cc)
defer cc.closeConn()
defer close(cc.readerDone)
@@ -2302,6 +2178,27 @@ func (rl *clientConnReadLoop) cleanup() {
}
cc.closed = true
+ // If the connection has never been used, and has been open for only a short time,
+ // leave it in the connection pool for a little while.
+ //
+ // This avoids a situation where new connections are constantly created,
+ // added to the pool, fail, and are removed from the pool, without any error
+ // being surfaced to the user.
+ unusedWaitTime := 5 * time.Second
+ if cc.idleTimeout > 0 && unusedWaitTime > cc.idleTimeout {
+ unusedWaitTime = cc.idleTimeout
+ }
+ idleTime := cc.t.now().Sub(cc.lastActive)
+ if atomic.LoadUint32(&cc.atomicReused) == 0 && idleTime < unusedWaitTime && !cc.closedOnIdle {
+ cc.idleTimer = cc.t.afterFunc(unusedWaitTime-idleTime, func() {
+ cc.t.connPool().MarkDead(cc)
+ })
+ } else {
+ cc.mu.Unlock() // avoid any deadlocks in MarkDead
+ cc.t.connPool().MarkDead(cc)
+ cc.mu.Lock()
+ }
+
for _, cs := range cc.streams {
select {
case <-cs.peerClosed:
@@ -2313,6 +2210,13 @@ func (rl *clientConnReadLoop) cleanup() {
}
cc.cond.Broadcast()
cc.mu.Unlock()
+
+ if !cc.seenSettings {
+ // If we have a pending request that wants extended CONNECT,
+ // let it continue and fail with the connection error.
+ cc.extendedConnectAllowed = true
+ close(cc.seenSettingsChan)
+ }
}
// countReadFrameError calls Transport.CountError with a string
@@ -2345,7 +2249,7 @@ func (cc *ClientConn) countReadFrameError(err error) {
func (rl *clientConnReadLoop) run() error {
cc := rl.cc
gotSettings := false
- readIdleTimeout := cc.t.ReadIdleTimeout
+ readIdleTimeout := cc.readIdleTimeout
var t timer
if readIdleTimeout != 0 {
t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck)
@@ -2359,7 +2263,7 @@ func (rl *clientConnReadLoop) run() error {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
if se, ok := err.(StreamError); ok {
- if cs := rl.streamByID(se.StreamID); cs != nil {
+ if cs := rl.streamByID(se.StreamID, notHeaderOrDataFrame); cs != nil {
if se.Cause == nil {
se.Cause = cc.fr.errDetail
}
@@ -2411,7 +2315,7 @@ func (rl *clientConnReadLoop) run() error {
}
func (rl *clientConnReadLoop) processHeaders(f *MetaHeadersFrame) error {
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, headerOrDataFrame)
if cs == nil {
// We'd get here if we canceled a request while the
// server had its response still in flight. So if this
@@ -2499,7 +2403,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range regularFields {
- key := canonicalHeader(hf.Name)
+ key := httpcommon.CanonicalHeader(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
@@ -2507,7 +2411,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
- t[canonicalHeader(v)] = nil
+ t[httpcommon.CanonicalHeader(v)] = nil
})
} else {
vv := header[key]
@@ -2529,15 +2433,34 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
if f.StreamEnded() {
return nil, errors.New("1xx informational response with END_STREAM flag")
}
- cs.num1xx++
- const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http
- if cs.num1xx > max1xxResponses {
- return nil, errors.New("http2: too many 1xx informational responses")
- }
if fn := cs.get1xxTraceFunc(); fn != nil {
+ // If the 1xx response is being delivered to the user,
+ // then they're responsible for limiting the number
+ // of responses.
if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil {
return nil, err
}
+ } else {
+ // If the user didn't examine the 1xx response, then we
+ // limit the size of all 1xx headers.
+ //
+ // This differs a bit from the HTTP/1 implementation, which
+ // limits the size of all 1xx headers plus the final response.
+ // Use the larger limit of MaxHeaderListSize and
+ // net/http.Transport.MaxResponseHeaderBytes.
+ limit := int64(cs.cc.t.maxHeaderListSize())
+ if t1 := cs.cc.t.t1; t1 != nil && t1.MaxResponseHeaderBytes > limit {
+ limit = t1.MaxResponseHeaderBytes
+ }
+ for _, h := range f.Fields {
+ cs.totalHeaderSize += int64(h.Size())
+ }
+ if cs.totalHeaderSize > limit {
+ if VerboseLogs {
+ log.Printf("http2: 1xx informational responses too large")
+ }
+ return nil, errors.New("header list too large")
+ }
}
if statusCode == 100 {
traceGot100Continue(cs.trace)
@@ -2612,7 +2535,7 @@ func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFr
trailer := make(http.Header)
for _, hf := range f.RegularFields() {
- key := canonicalHeader(hf.Name)
+ key := httpcommon.CanonicalHeader(hf.Name)
trailer[key] = append(trailer[key], hf.Value)
}
cs.trailer = trailer
@@ -2721,7 +2644,7 @@ func (b transportResponseBody) Close() error {
func (rl *clientConnReadLoop) processData(f *DataFrame) error {
cc := rl.cc
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, headerOrDataFrame)
data := f.Data()
if cs == nil {
cc.mu.Lock()
@@ -2856,9 +2779,22 @@ func (rl *clientConnReadLoop) endStreamError(cs *clientStream, err error) {
cs.abortStream(err)
}
-func (rl *clientConnReadLoop) streamByID(id uint32) *clientStream {
+// Constants passed to streamByID for documentation purposes.
+const (
+ headerOrDataFrame = true
+ notHeaderOrDataFrame = false
+)
+
+// streamByID returns the stream with the given id, or nil if no stream has that id.
+// If headerOrData is true, it clears cc.rstStreamPingsBlocked.
+func (rl *clientConnReadLoop) streamByID(id uint32, headerOrData bool) *clientStream {
rl.cc.mu.Lock()
defer rl.cc.mu.Unlock()
+ if headerOrData {
+ // Work around an unfortunate gRPC behavior.
+ // See comment on ClientConn.rstStreamPingsBlocked for details.
+ rl.cc.rstStreamPingsBlocked = false
+ }
cs := rl.cc.streams[id]
if cs != nil && !cs.readAborted {
return cs
@@ -2952,6 +2888,21 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
case SettingHeaderTableSize:
cc.henc.SetMaxDynamicTableSize(s.Val)
cc.peerMaxHeaderTableSize = s.Val
+ case SettingEnableConnectProtocol:
+ if err := s.Valid(); err != nil {
+ return err
+ }
+ // If the peer wants to send us SETTINGS_ENABLE_CONNECT_PROTOCOL,
+ // we require that it do so in the first SETTINGS frame.
+ //
+ // When we attempt to use extended CONNECT, we wait for the first
+ // SETTINGS frame to see if the server supports it. If we let the
+ // server enable the feature with a later SETTINGS frame, then
+ // users will see inconsistent results depending on whether we've
+ // seen that frame or not.
+ if !cc.seenSettings {
+ cc.extendedConnectAllowed = s.Val == 1
+ }
default:
cc.vlogf("Unhandled Setting: %v", s)
}
@@ -2969,6 +2920,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
// connection can establish to our default.
cc.maxConcurrentStreams = defaultMaxConcurrentStreams
}
+ close(cc.seenSettingsChan)
cc.seenSettings = true
}
@@ -2977,7 +2929,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
cc := rl.cc
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, notHeaderOrDataFrame)
if f.StreamID != 0 && cs == nil {
return nil
}
@@ -3006,7 +2958,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error {
}
func (rl *clientConnReadLoop) processResetStream(f *RSTStreamFrame) error {
- cs := rl.streamByID(f.StreamID)
+ cs := rl.streamByID(f.StreamID, notHeaderOrDataFrame)
if cs == nil {
// TODO: return error if server tries to RST_STREAM an idle stream
return nil
@@ -3081,6 +3033,12 @@ func (rl *clientConnReadLoop) processPing(f *PingFrame) error {
close(c)
delete(cc.pings, f.Data)
}
+ if cc.pendingResets > 0 {
+ // See clientStream.cleanupWriteRequest.
+ cc.pendingResets = 0
+ cc.rstStreamPingsBlocked = true
+ cc.cond.Broadcast()
+ }
return nil
}
cc := rl.cc
@@ -3103,20 +3061,27 @@ func (rl *clientConnReadLoop) processPushPromise(f *PushPromiseFrame) error {
return ConnectionError(ErrCodeProtocol)
}
-func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, err error) {
+// writeStreamReset sends a RST_STREAM frame.
+// When ping is true, it also sends a PING frame with a random payload.
+func (cc *ClientConn) writeStreamReset(streamID uint32, code ErrCode, ping bool, err error) {
// TODO: map err to more interesting error codes, once the
// HTTP community comes up with some. But currently for
// RST_STREAM there's no equivalent to GOAWAY frame's debug
// data, and the error codes are all pretty vague ("cancel").
cc.wmu.Lock()
cc.fr.WriteRSTStream(streamID, code)
+ if ping {
+ var payload [8]byte
+ rand.Read(payload[:])
+ cc.fr.WritePing(false, payload)
+ }
cc.bw.Flush()
cc.wmu.Unlock()
}
var (
errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
- errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit")
+ errRequestHeaderListSize = httpcommon.ErrRequestHeaderListSize
)
func (cc *ClientConn) logf(format string, args ...interface{}) {
@@ -3263,7 +3228,7 @@ func traceGotConn(req *http.Request, cc *ClientConn, reused bool) {
cc.mu.Lock()
ci.WasIdle = len(cc.streams) == 0 && reused
if ci.WasIdle && !cc.lastActive.IsZero() {
- ci.IdleTime = time.Since(cc.lastActive)
+ ci.IdleTime = cc.t.timeSince(cc.lastActive)
}
cc.mu.Unlock()
@@ -3300,16 +3265,6 @@ func traceFirstResponseByte(trace *httptrace.ClientTrace) {
}
}
-func traceHasWroteHeaderField(trace *httptrace.ClientTrace) bool {
- return trace != nil && trace.WroteHeaderField != nil
-}
-
-func traceWroteHeaderField(trace *httptrace.ClientTrace, k, v string) {
- if trace != nil && trace.WroteHeaderField != nil {
- trace.WroteHeaderField(k, []string{v})
- }
-}
-
func traceGot1xxResponseFunc(trace *httptrace.ClientTrace) func(int, textproto.MIMEHeader) error {
if trace != nil {
return trace.Got1xxResponse
diff --git a/http2/transport_test.go b/http2/transport_test.go
index 498e27932c..1eeb76e06e 100644
--- a/http2/transport_test.go
+++ b/http2/transport_test.go
@@ -1420,7 +1420,7 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
res0.Body.Close()
res, err := tr.RoundTrip(req)
- if err != wantErr {
+ if !errors.Is(err, wantErr) {
if res != nil {
res.Body.Close()
}
@@ -1443,26 +1443,14 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) {
}
}
headerListSizeForRequest := func(req *http.Request) (size uint64) {
- contentLen := actualContentLength(req)
- trailers, err := commaSeparatedTrailers(req)
- if err != nil {
- t.Fatalf("headerListSizeForRequest: %v", err)
- }
- cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
- cc.henc = hpack.NewEncoder(&cc.hbuf)
- cc.mu.Lock()
- hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
- cc.mu.Unlock()
- if err != nil {
- t.Fatalf("headerListSizeForRequest: %v", err)
- }
- hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
+ const addGzipHeader = true
+ const peerMaxHeaderListSize = 0xffffffffffffffff
+ _, err := encodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) {
+ hf := hpack.HeaderField{Name: name, Value: value}
size += uint64(hf.Size())
})
- if len(hdrs) > 0 {
- if _, err := hpackDec.Write(hdrs); err != nil {
- t.Fatalf("headerListSizeForRequest: %v", err)
- }
+ if err != nil {
+ t.Fatal(err)
}
return size
}
@@ -2559,6 +2547,9 @@ func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
}
return true
},
+ func(f *PingFrame) bool {
+ return true
+ },
func(f *WindowUpdateFrame) bool {
if !oneDataFrame && !sentAdditionalData {
t.Fatalf("Got WindowUpdateFrame, don't expect one yet")
@@ -2850,11 +2841,15 @@ func TestTransportRequestPathPseudo(t *testing.T) {
},
}
for i, tt := range tests {
- cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
- cc.henc = hpack.NewEncoder(&cc.hbuf)
- cc.mu.Lock()
- hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
- cc.mu.Unlock()
+ hbuf := &bytes.Buffer{}
+ henc := hpack.NewEncoder(hbuf)
+
+ const addGzipHeader = false
+ const peerMaxHeaderListSize = 0xffffffffffffffff
+ _, err := encodeRequestHeaders(tt.req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) {
+ henc.WriteField(hpack.HeaderField{Name: name, Value: value})
+ })
+ hdrs := hbuf.Bytes()
var got result
hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
if f.Name == ":path" {
@@ -5421,3 +5416,522 @@ func TestIssue67671(t *testing.T) {
res.Body.Close()
}
}
+
+func TestTransport1xxLimits(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ opt any
+ ctxfn func(context.Context) context.Context
+ hcount int
+ limited bool
+ }{{
+ name: "default",
+ hcount: 10,
+ limited: false,
+ }, {
+ name: "MaxHeaderListSize",
+ opt: func(tr *Transport) {
+ tr.MaxHeaderListSize = 10000
+ },
+ hcount: 10,
+ limited: true,
+ }, {
+ name: "MaxResponseHeaderBytes",
+ opt: func(tr *http.Transport) {
+ tr.MaxResponseHeaderBytes = 10000
+ },
+ hcount: 10,
+ limited: true,
+ }, {
+ name: "limit by client trace",
+ ctxfn: func(ctx context.Context) context.Context {
+ count := 0
+ return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ count++
+ if count >= 10 {
+ return errors.New("too many 1xx")
+ }
+ return nil
+ },
+ })
+ },
+ hcount: 10,
+ limited: true,
+ }, {
+ name: "limit disabled by client trace",
+ opt: func(tr *Transport) {
+ tr.MaxHeaderListSize = 10000
+ },
+ ctxfn: func(ctx context.Context) context.Context {
+ return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ return nil
+ },
+ })
+ },
+ hcount: 20,
+ limited: false,
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ tc := newTestClientConn(t, test.opt)
+ tc.greet()
+
+ ctx := context.Background()
+ if test.ctxfn != nil {
+ ctx = test.ctxfn(ctx)
+ }
+ req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+
+ for i := 0; i < test.hcount; i++ {
+ if fr, err := tc.fr.ReadFrame(); err != os.ErrDeadlineExceeded {
+ t.Fatalf("after writing %v 1xx headers: read %v, %v; want idle", i, fr, err)
+ }
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ EndStream: false,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "103",
+ "x-field", strings.Repeat("a", 1000),
+ ),
+ })
+ }
+ if test.limited {
+ tc.wantFrameType(FrameRSTStream)
+ } else {
+ tc.wantIdle()
+ }
+ })
+ }
+}
+
+func TestTransportSendPingWithReset(t *testing.T) {
+ tc := newTestClientConn(t, func(tr *Transport) {
+ tr.StrictMaxConcurrentStreams = true
+ })
+
+ const maxConcurrent = 3
+ tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+
+ // Start several requests.
+ var rts []*testRoundTrip
+ for i := 0; i < maxConcurrent+1; i++ {
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tc.roundTrip(req)
+ if i >= maxConcurrent {
+ tc.wantIdle()
+ continue
+ }
+ tc.wantFrameType(FrameHeaders)
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: rt.streamID(),
+ EndHeaders: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
+ rts = append(rts, rt)
+ }
+
+ // Cancel one request. We send a PING frame along with the RST_STREAM.
+ rts[0].response().Body.Close()
+ tc.wantRSTStream(rts[0].streamID(), ErrCodeCancel)
+ pf := readFrame[*PingFrame](t, tc)
+ tc.wantIdle()
+
+ // Cancel another request. No PING frame, since one is in flight.
+ rts[1].response().Body.Close()
+ tc.wantRSTStream(rts[1].streamID(), ErrCodeCancel)
+ tc.wantIdle()
+
+ // Respond to the PING.
+ // This finalizes the previous resets, and allows the pending request to be sent.
+ tc.writePing(true, pf.Data)
+ tc.wantFrameType(FrameHeaders)
+ tc.wantIdle()
+
+ // Receive a byte of data for the remaining stream, which resets our ability
+ // to send pings (see comment on ClientConn.rstStreamPingsBlocked).
+ tc.writeData(rts[2].streamID(), false, []byte{0})
+
+ // Cancel the last request. We send another PING, since none are in flight.
+ rts[2].response().Body.Close()
+ tc.wantRSTStream(rts[2].streamID(), ErrCodeCancel)
+ tc.wantFrameType(FramePing)
+ tc.wantIdle()
+}
+
+// Issue #70505: gRPC gets upset if we send more than 2 pings per HEADERS/DATA frame
+// sent by the server.
+func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ makeAndResetRequest := func() {
+ t.Helper()
+ ctx, cancel := context.WithCancel(context.Background())
+ req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
+ rt := tc.roundTrip(req)
+ tc.wantFrameType(FrameHeaders)
+ cancel()
+ tc.wantRSTStream(rt.streamID(), ErrCodeCancel) // client sends RST_STREAM
+ }
+
+ // Create a request and cancel it.
+ // The client sends a PING frame along with the reset.
+ makeAndResetRequest()
+ pf1 := readFrame[*PingFrame](t, tc) // client sends PING
+
+ // Create another request and cancel it.
+ // We do not send a PING frame along with the reset,
+ // because we haven't received a HEADERS or DATA frame from the server
+ // since the last PING we sent.
+ makeAndResetRequest()
+
+ // Server belatedly responds to request 1.
+ // The server has not responded to our first PING yet.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+
+ // Create yet another request and cancel it.
+ // We still do not send a PING frame along with the reset.
+ // We've received a HEADERS frame, but it came before the response to the PING.
+ makeAndResetRequest()
+
+ // The server responds to our PING.
+ tc.writePing(true, pf1.Data)
+
+ // Create yet another request and cancel it.
+ // Still no PING frame; we got a response to the previous one,
+ // but no HEADERS or DATA.
+ makeAndResetRequest()
+
+ // Server belatedly responds to the second request.
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 3,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+
+ // One more request.
+ // This time we send a PING frame.
+ makeAndResetRequest()
+ tc.wantFrameType(FramePing)
+}
+
+func TestTransportConnBecomesUnresponsive(t *testing.T) {
+ // We send a number of requests in series to an unresponsive connection.
+ // Each request is canceled or times out without a response.
+ // Eventually, we open a new connection rather than trying to use the old one.
+ tt := newTestTransport(t)
+
+ const maxConcurrent = 3
+
+ t.Logf("first request opens a new connection and succeeds")
+ req1 := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt1 := tt.roundTrip(req1)
+ tc1 := tt.getConn()
+ tc1.wantFrameType(FrameSettings)
+ tc1.wantFrameType(FrameWindowUpdate)
+ hf1 := readFrame[*HeadersFrame](t, tc1)
+ tc1.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+ tc1.wantFrameType(FrameSettings) // ack
+ tc1.writeHeaders(HeadersFrameParam{
+ StreamID: hf1.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc1.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt1.wantStatus(200)
+ rt1.response().Body.Close()
+
+ // Send more requests.
+ // None receive a response.
+ // Each is canceled.
+ for i := 0; i < maxConcurrent; i++ {
+ t.Logf("request %v receives no response and is canceled", i)
+ ctx, cancel := context.WithCancel(context.Background())
+ req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
+ tt.roundTrip(req)
+ if tt.hasConn() {
+ t.Fatalf("new connection created; expect existing conn to be reused")
+ }
+ tc1.wantFrameType(FrameHeaders)
+ cancel()
+ tc1.wantFrameType(FrameRSTStream)
+ if i == 0 {
+ tc1.wantFrameType(FramePing)
+ }
+ tc1.wantIdle()
+ }
+
+ // The conn has hit its concurrency limit.
+ // The next request is sent on a new conn.
+ req2 := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt2 := tt.roundTrip(req2)
+ tc2 := tt.getConn()
+ tc2.wantFrameType(FrameSettings)
+ tc2.wantFrameType(FrameWindowUpdate)
+ hf := readFrame[*HeadersFrame](t, tc2)
+ tc2.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
+ tc2.wantFrameType(FrameSettings) // ack
+ tc2.writeHeaders(HeadersFrameParam{
+ StreamID: hf.StreamID,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc2.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt2.wantStatus(200)
+ rt2.response().Body.Close()
+}
+
+// Test that the Transport can use a conn provided to it by a TLSNextProto hook.
+func TestTransportTLSNextProtoConnOK(t *testing.T) {
+ t1 := &http.Transport{}
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+ tc.greet()
+
+ // Send a request on the Transport.
+ // It uses the conn we provided.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ tc.wantHeaders(wantHeader{
+ streamID: 1,
+ endStream: true,
+ header: http.Header{
+ ":authority": []string{"dummy.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{"/"},
+ },
+ })
+ tc.writeHeaders(HeadersFrameParam{
+ StreamID: 1,
+ EndHeaders: true,
+ EndStream: true,
+ BlockFragment: tc.makeHeaderBlockFragment(
+ ":status", "200",
+ ),
+ })
+ rt.wantStatus(200)
+ rt.wantBody(nil)
+}
+
+// Test the case where a conn provided via a TLSNextProto hook immediately encounters an error.
+func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) {
+ t1 := &http.Transport{}
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+
+ // The connection encounters an error before we send a request that uses it.
+ tc.closeWrite()
+
+ // Send a request on the Transport.
+ //
+ // It should fail, because we have no usable connections, but not with ErrNoCachedConn.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ if err := rt.err(); err == nil || errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip with broken conn: got %v, want an error other than ErrNoCachedConn", err)
+ }
+
+ // Send the request again.
+ // This time it should fail with ErrNoCachedConn,
+ // because the dead conn has been removed from the pool.
+ rt = tt.roundTrip(req)
+ if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip after broken conn is used: got %v, want ErrNoCachedConn", err)
+ }
+}
+
+// Test the case where a conn provided via a TLSNextProto hook is closed for idleness
+// before we use it.
+func TestTransportTLSNextProtoConnIdleTimoutBeforeUse(t *testing.T) {
+ t1 := &http.Transport{
+ IdleConnTimeout: 1 * time.Second,
+ }
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+
+	// The connection is closed for idleness before we send a request that uses it.
+ tc.advance(2 * time.Second)
+
+ // Send a request on the Transport.
+ //
+ // It should fail with ErrNoCachedConn.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip with conn closed for idleness: got %v, want ErrNoCachedConn", err)
+ }
+}
+
+// Test the case where a conn provided via a TLSNextProto hook immediately encounters an error,
+// but no requests are sent which would use the bad connection.
+func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) {
+ t1 := &http.Transport{}
+ t2, _ := ConfigureTransports(t1)
+ tt := newTestTransport(t, t2)
+
+ // Create a new, fake connection and pass it to the Transport via the TLSNextProto hook.
+ cli, _ := synctestNetPipe(tt.group)
+ cliTLS := tls.Client(cli, tlsConfigInsecure)
+ go func() {
+ tt.group.Join()
+ t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
+ }()
+ tt.sync()
+ tc := tt.getConn()
+
+ // The connection encounters an error before we send a request that uses it.
+ tc.closeWrite()
+
+ // Some time passes.
+ // The dead connection is removed from the pool.
+ tc.advance(10 * time.Second)
+
+ // Send a request on the Transport.
+ //
+ // It should fail with ErrNoCachedConn, because the pool contains no conns.
+ req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
+ rt := tt.roundTrip(req)
+ if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
+ t.Fatalf("RoundTrip after broken conn expires: got %v, want ErrNoCachedConn", err)
+ }
+}
+
+func TestExtendedConnectClientWithServerSupport(t *testing.T) {
+ setForTest(t, &disableExtendedConnectProtocol, false)
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ if r.Header.Get(":protocol") != "extended-connect" {
+ t.Fatalf("unexpected :protocol header received")
+ }
+ t.Log(io.Copy(w, r.Body))
+ })
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ AllowHTTP: true,
+ }
+ defer tr.CloseIdleConnections()
+ pr, pw := io.Pipe()
+ pwDone := make(chan struct{})
+ req, _ := http.NewRequest("CONNECT", ts.URL, pr)
+ req.Header.Set(":protocol", "extended-connect")
+ req.Header.Set("X-A", "A")
+ req.Header.Set("X-B", "B")
+ req.Header.Set("X-C", "C")
+ go func() {
+ pw.Write([]byte("hello, extended connect"))
+ pw.Close()
+ close(pwDone)
+ }()
+
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(body, []byte("hello, extended connect")) {
+ t.Fatal("unexpected body received")
+ }
+}
+
+func TestExtendedConnectClientWithoutServerSupport(t *testing.T) {
+ setForTest(t, &disableExtendedConnectProtocol, true)
+ ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
+ io.Copy(w, r.Body)
+ })
+ tr := &Transport{
+ TLSClientConfig: tlsConfigInsecure,
+ AllowHTTP: true,
+ }
+ defer tr.CloseIdleConnections()
+ pr, pw := io.Pipe()
+ pwDone := make(chan struct{})
+ req, _ := http.NewRequest("CONNECT", ts.URL, pr)
+ req.Header.Set(":protocol", "extended-connect")
+ req.Header.Set("X-A", "A")
+ req.Header.Set("X-B", "B")
+ req.Header.Set("X-C", "C")
+ go func() {
+ pw.Write([]byte("hello, extended connect"))
+ pw.Close()
+ close(pwDone)
+ }()
+
+ _, err := tr.RoundTrip(req)
+ if !errors.Is(err, errExtendedConnectNotSupported) {
+ t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err)
+ }
+}
+
+// Issue #70658: Make sure extended CONNECT requests don't get stuck if a
+// connection fails early in its lifetime.
+func TestExtendedConnectReadFrameError(t *testing.T) {
+ tc := newTestClientConn(t)
+ tc.wantFrameType(FrameSettings)
+ tc.wantFrameType(FrameWindowUpdate)
+
+ req, _ := http.NewRequest("CONNECT", "https://dummy.tld/", nil)
+ req.Header.Set(":protocol", "extended-connect")
+ rt := tc.roundTrip(req)
+ tc.wantIdle() // waiting for SETTINGS response
+
+ tc.closeWrite() // connection breaks without sending SETTINGS
+ if !rt.done() {
+ t.Fatalf("after connection closed: RoundTrip still running; want done")
+ }
+ if rt.err() == nil {
+ t.Fatalf("after connection closed: RoundTrip succeeded; want error")
+ }
+}
diff --git a/http2/unencrypted.go b/http2/unencrypted.go
new file mode 100644
index 0000000000..b2de211613
--- /dev/null
+++ b/http2/unencrypted.go
@@ -0,0 +1,32 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http2
+
+import (
+ "crypto/tls"
+ "errors"
+ "net"
+)
+
+const nextProtoUnencryptedHTTP2 = "unencrypted_http2"
+
+// unencryptedNetConnFromTLSConn retrieves a net.Conn wrapped in a *tls.Conn.
+//
+// TLSNextProto functions accept a *tls.Conn.
+//
+// When passing an unencrypted HTTP/2 connection to a TLSNextProto function,
+// we pass a *tls.Conn with an underlying net.Conn containing the unencrypted connection.
+// To be extra careful about mistakes (accidentally dropping TLS encryption in a place
+// where we want it), the tls.Conn contains a net.Conn with an UnencryptedNetConn method
+// that returns the actual connection we want to use.
+func unencryptedNetConnFromTLSConn(tc *tls.Conn) (net.Conn, error) {
+ conner, ok := tc.NetConn().(interface {
+ UnencryptedNetConn() net.Conn
+ })
+ if !ok {
+ return nil, errors.New("http2: TLS conn unexpectedly found in unencrypted handoff")
+ }
+ return conner.UnencryptedNetConn(), nil
+}
diff --git a/http2/write.go b/http2/write.go
index 33f61398a1..fdb35b9477 100644
--- a/http2/write.go
+++ b/http2/write.go
@@ -13,6 +13,7 @@ import (
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2/hpack"
+ "golang.org/x/net/internal/httpcommon"
)
// writeFramer is implemented by any type that is used to write frames.
@@ -131,6 +132,16 @@ func (se StreamError) writeFrame(ctx writeContext) error {
func (se StreamError) staysWithinBuffer(max int) bool { return frameHeaderLen+4 <= max }
+type writePing struct {
+ data [8]byte
+}
+
+func (w writePing) writeFrame(ctx writeContext) error {
+ return ctx.Framer().WritePing(false, w.data)
+}
+
+func (w writePing) staysWithinBuffer(max int) bool { return frameHeaderLen+len(w.data) <= max }
+
type writePingAck struct{ pf *PingFrame }
func (w writePingAck) writeFrame(ctx writeContext) error {
@@ -341,7 +352,7 @@ func encodeHeaders(enc *hpack.Encoder, h http.Header, keys []string) {
}
for _, k := range keys {
vv := h[k]
- k, ascii := lowerHeader(k)
+ k, ascii := httpcommon.LowerHeader(k)
if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x).
diff --git a/http2/gate_test.go b/internal/gate/gate.go
similarity index 52%
rename from http2/gate_test.go
rename to internal/gate/gate.go
index e5e6a315be..5c026c002d 100644
--- a/http2/gate_test.go
+++ b/internal/gate/gate.go
@@ -1,40 +1,37 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package http2
+
+// Package gate contains an alternative condition variable.
+package gate
import "context"
-// An gate is a monitor (mutex + condition variable) with one bit of state.
+// A Gate is a monitor (mutex + condition variable) with one bit of state.
//
// The condition may be either set or unset.
// Lock operations may be unconditional, or wait for the condition to be set.
// Unlock operations record the new state of the condition.
-type gate struct {
+type Gate struct {
// When unlocked, exactly one of set or unset contains a value.
// When locked, neither chan contains a value.
set chan struct{}
unset chan struct{}
}
-// newGate returns a new, unlocked gate with the condition unset.
-func newGate() gate {
- g := newLockedGate()
- g.unlock(false)
- return g
-}
-
-// newLocked gate returns a new, locked gate.
-func newLockedGate() gate {
- return gate{
+// New returns a new, unlocked gate with the condition initialized to set.
+func New(set bool) Gate {
+ g := Gate{
set: make(chan struct{}, 1),
unset: make(chan struct{}, 1),
}
+ g.Unlock(set)
+ return g
}
-// lock acquires the gate unconditionally.
+// Lock acquires the gate unconditionally.
// It reports whether the condition is set.
-func (g *gate) lock() (set bool) {
+func (g *Gate) Lock() (set bool) {
select {
case <-g.set:
return true
@@ -43,9 +40,9 @@ func (g *gate) lock() (set bool) {
}
}
-// waitAndLock waits until the condition is set before acquiring the gate.
-// If the context expires, waitAndLock returns an error and does not acquire the gate.
-func (g *gate) waitAndLock(ctx context.Context) error {
+// WaitAndLock waits until the condition is set before acquiring the gate.
+// If the context expires, WaitAndLock returns an error and does not acquire the gate.
+func (g *Gate) WaitAndLock(ctx context.Context) error {
select {
case <-g.set:
return nil
@@ -59,8 +56,8 @@ func (g *gate) waitAndLock(ctx context.Context) error {
}
}
-// lockIfSet acquires the gate if and only if the condition is set.
-func (g *gate) lockIfSet() (acquired bool) {
+// LockIfSet acquires the gate if and only if the condition is set.
+func (g *Gate) LockIfSet() (acquired bool) {
select {
case <-g.set:
return true
@@ -69,17 +66,11 @@ func (g *gate) lockIfSet() (acquired bool) {
}
}
-// unlock sets the condition and releases the gate.
-func (g *gate) unlock(set bool) {
+// Unlock sets the condition and releases the gate.
+func (g *Gate) Unlock(set bool) {
if set {
g.set <- struct{}{}
} else {
g.unset <- struct{}{}
}
}
-
-// unlock sets the condition to the result of f and releases the gate.
-// Useful in defers.
-func (g *gate) unlockFunc(f func() bool) {
- g.unlock(f())
-}
diff --git a/internal/gate/gate_test.go b/internal/gate/gate_test.go
new file mode 100644
index 0000000000..87a78b15af
--- /dev/null
+++ b/internal/gate/gate_test.go
@@ -0,0 +1,85 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gate_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "golang.org/x/net/internal/gate"
+)
+
+func TestGateLockAndUnlock(t *testing.T) {
+ g := gate.New(false)
+ if set := g.Lock(); set {
+ t.Errorf("g.Lock of never-locked gate: true, want false")
+ }
+ unlockedc := make(chan struct{})
+ donec := make(chan struct{})
+ go func() {
+ defer close(donec)
+ if set := g.Lock(); !set {
+ t.Errorf("g.Lock of set gate: false, want true")
+ }
+ select {
+ case <-unlockedc:
+ default:
+ t.Errorf("g.Lock succeeded while gate was held")
+ }
+ g.Unlock(false)
+ }()
+ time.Sleep(1 * time.Millisecond)
+ close(unlockedc)
+ g.Unlock(true)
+ <-donec
+ if set := g.Lock(); set {
+ t.Errorf("g.Lock of unset gate: true, want false")
+ }
+}
+
+func TestGateWaitAndLock(t *testing.T) {
+ g := gate.New(false)
+ // WaitAndLock is canceled.
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+ if err := g.WaitAndLock(ctx); err != context.DeadlineExceeded {
+ t.Fatalf("g.WaitAndLock = %v, want context.DeadlineExceeded", err)
+ }
+ // WaitAndLock succeeds.
+ set := false
+ go func() {
+ time.Sleep(1 * time.Millisecond)
+ g.Lock()
+ set = true
+ g.Unlock(true)
+ }()
+ if err := g.WaitAndLock(context.Background()); err != nil {
+ t.Fatalf("g.WaitAndLock = %v, want nil", err)
+ }
+ if !set {
+ t.Fatalf("g.WaitAndLock returned before gate was set")
+ }
+ g.Unlock(true)
+ // WaitAndLock succeeds when the gate is set and the context is canceled.
+ if err := g.WaitAndLock(ctx); err != nil {
+ t.Fatalf("g.WaitAndLock = %v, want nil", err)
+ }
+}
+
+func TestGateLockIfSet(t *testing.T) {
+ g := gate.New(false)
+ if locked := g.LockIfSet(); locked {
+ t.Fatalf("g.LockIfSet of unset gate = %v, want false", locked)
+ }
+ g.Lock()
+ if locked := g.LockIfSet(); locked {
+ t.Fatalf("g.LockIfSet of locked gate = %v, want false", locked)
+ }
+ g.Unlock(true)
+ if locked := g.LockIfSet(); !locked {
+ t.Fatalf("g.LockIfSet of set gate = %v, want true", locked)
+ }
+}
diff --git a/internal/http3/body.go b/internal/http3/body.go
new file mode 100644
index 0000000000..cdde482efb
--- /dev/null
+++ b/internal/http3/body.go
@@ -0,0 +1,142 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "sync"
+)
+
+// A bodyWriter writes a request or response body to a stream
+// as a series of DATA frames.
+type bodyWriter struct {
+ st *stream
+ remain int64 // -1 when content-length is not known
+ flush bool // flush the stream after every write
+ name string // "request" or "response"
+}
+
+func (w *bodyWriter) Write(p []byte) (n int, err error) {
+ if w.remain >= 0 && int64(len(p)) > w.remain {
+ return 0, &streamError{
+ code: errH3InternalError,
+ message: w.name + " body longer than specified content length",
+ }
+ }
+ w.st.writeVarint(int64(frameTypeData))
+ w.st.writeVarint(int64(len(p)))
+ n, err = w.st.Write(p)
+ if w.remain >= 0 {
+ w.remain -= int64(n)
+ }
+ if w.flush && err == nil {
+ err = w.st.Flush()
+ }
+ if err != nil {
+ err = fmt.Errorf("writing %v body: %w", w.name, err)
+ }
+ return n, err
+}
+
+func (w *bodyWriter) Close() error {
+ if w.remain > 0 {
+ return errors.New(w.name + " body shorter than specified content length")
+ }
+ return nil
+}
+
+// A bodyReader reads a request or response body from a stream.
+type bodyReader struct {
+ st *stream
+
+ mu sync.Mutex
+ remain int64
+ err error
+}
+
+func (r *bodyReader) Read(p []byte) (n int, err error) {
+ // The HTTP/1 and HTTP/2 implementations both permit concurrent reads from a body,
+ // in the sense that the race detector won't complain.
+ // Use a mutex here to provide the same behavior.
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ if r.err != nil {
+ return 0, r.err
+ }
+ defer func() {
+ if err != nil {
+ r.err = err
+ }
+ }()
+ if r.st.lim == 0 {
+ // We've finished reading the previous DATA frame, so end it.
+ if err := r.st.endFrame(); err != nil {
+ return 0, err
+ }
+ }
+ // Read the next DATA frame header,
+ // if we aren't already in the middle of one.
+ for r.st.lim < 0 {
+ ftype, err := r.st.readFrameHeader()
+ if err == io.EOF && r.remain > 0 {
+ return 0, &streamError{
+ code: errH3MessageError,
+ message: "body shorter than content-length",
+ }
+ }
+ if err != nil {
+ return 0, err
+ }
+ switch ftype {
+ case frameTypeData:
+ if r.remain >= 0 && r.st.lim > r.remain {
+ return 0, &streamError{
+ code: errH3MessageError,
+ message: "body longer than content-length",
+ }
+ }
+ // Fall out of the loop and process the frame body below.
+ case frameTypeHeaders:
+ // This HEADERS frame contains the message trailers.
+ if r.remain > 0 {
+ return 0, &streamError{
+ code: errH3MessageError,
+ message: "body shorter than content-length",
+ }
+ }
+ // TODO: Fill in Request.Trailer.
+ if err := r.st.discardFrame(); err != nil {
+ return 0, err
+ }
+ return 0, io.EOF
+ default:
+ if err := r.st.discardUnknownFrame(ftype); err != nil {
+ return 0, err
+ }
+ }
+ }
+ // We are now reading the content of a DATA frame.
+ // Fill the read buffer or read to the end of the frame,
+ // whichever comes first.
+ if int64(len(p)) > r.st.lim {
+ p = p[:r.st.lim]
+ }
+ n, err = r.st.Read(p)
+ if r.remain > 0 {
+ r.remain -= int64(n)
+ }
+ return n, err
+}
+
+func (r *bodyReader) Close() error {
+ // Unlike the HTTP/1 and HTTP/2 body readers (at the time of this comment being written),
+ // calling Close concurrently with Read will interrupt the read.
+ r.st.stream.CloseRead()
+ return nil
+}
diff --git a/internal/http3/body_test.go b/internal/http3/body_test.go
new file mode 100644
index 0000000000..599e0df816
--- /dev/null
+++ b/internal/http3/body_test.go
@@ -0,0 +1,276 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "net/http"
+ "testing"
+)
+
+// TestReadData tests servers reading request bodies, and clients reading response bodies.
+func TestReadData(t *testing.T) {
+ // These tests consist of a series of steps,
+ // where each step is either something arriving on the stream
+ // or the client/server reading from the body.
+ type (
+ // HEADERS frame arrives (headers).
+ receiveHeaders struct {
+ contentLength int64 // -1 for no content-length
+ }
+ // DATA frame header arrives.
+ receiveDataHeader struct {
+ size int64
+ }
+ // DATA frame content arrives.
+ receiveData struct {
+ size int64
+ }
+ // HEADERS frame arrives (trailers).
+ receiveTrailers struct{}
+ // Some other frame arrives.
+ receiveFrame struct {
+ ftype frameType
+ data []byte
+ }
+ // Stream closed, ending the body.
+ receiveEOF struct{}
+ // Server reads from Request.Body, or client reads from Response.Body.
+ wantBody struct {
+ size int64
+ eof bool
+ }
+ wantError struct{}
+ )
+ for _, test := range []struct {
+ name string
+ respHeader http.Header
+ steps []any
+ wantError bool
+ }{{
+ name: "no content length",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ receiveEOF{},
+ wantBody{size: 10, eof: true},
+ },
+ }, {
+ name: "valid content length",
+ steps: []any{
+ receiveHeaders{contentLength: 10},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ receiveEOF{},
+ wantBody{size: 10, eof: true},
+ },
+ }, {
+ name: "data frame exceeds content length",
+ steps: []any{
+ receiveHeaders{contentLength: 5},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ wantError{},
+ },
+ }, {
+ name: "data frame after all content read",
+ steps: []any{
+ receiveHeaders{contentLength: 5},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ wantBody{size: 5},
+ receiveDataHeader{size: 1},
+ receiveData{size: 1},
+ wantError{},
+ },
+ }, {
+ name: "content length too long",
+ steps: []any{
+ receiveHeaders{contentLength: 10},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ receiveEOF{},
+ wantBody{size: 5},
+ wantError{},
+ },
+ }, {
+ name: "stream ended by trailers",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ receiveTrailers{},
+ wantBody{size: 5, eof: true},
+ },
+ }, {
+ name: "trailers and content length too long",
+ steps: []any{
+ receiveHeaders{contentLength: 10},
+ receiveDataHeader{size: 5},
+ receiveData{size: 5},
+ wantBody{size: 5},
+ receiveTrailers{},
+ wantError{},
+ },
+ }, {
+ name: "unknown frame before headers",
+ steps: []any{
+ receiveFrame{
+ ftype: 0x1f + 0x21, // reserved frame type
+ data: []byte{1, 2, 3, 4},
+ },
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ wantBody{size: 10},
+ },
+ }, {
+ name: "unknown frame after headers",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveFrame{
+ ftype: 0x1f + 0x21, // reserved frame type
+ data: []byte{1, 2, 3, 4},
+ },
+ receiveDataHeader{size: 10},
+ receiveData{size: 10},
+ wantBody{size: 10},
+ },
+ }, {
+ name: "invalid frame",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveFrame{
+ ftype: frameTypeSettings, // not a valid frame on this stream
+ data: []byte{1, 2, 3, 4},
+ },
+ wantError{},
+ },
+ }, {
+ name: "data frame consumed by several reads",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 16},
+ receiveData{size: 16},
+ wantBody{size: 2},
+ wantBody{size: 4},
+ wantBody{size: 8},
+ wantBody{size: 2},
+ },
+ }, {
+ name: "read multiple frames",
+ steps: []any{
+ receiveHeaders{contentLength: -1},
+ receiveDataHeader{size: 2},
+ receiveData{size: 2},
+ receiveDataHeader{size: 4},
+ receiveData{size: 4},
+ receiveDataHeader{size: 8},
+ receiveData{size: 8},
+ wantBody{size: 2},
+ wantBody{size: 4},
+ wantBody{size: 8},
+ },
+ }} {
+
+ runTest := func(t testing.TB, h http.Header, st *testQUICStream, body func() io.ReadCloser) {
+ var (
+ bytesSent int
+ bytesReceived int
+ )
+ for _, step := range test.steps {
+ switch step := step.(type) {
+ case receiveHeaders:
+ header := h.Clone()
+ if step.contentLength != -1 {
+ header["content-length"] = []string{
+ fmt.Sprint(step.contentLength),
+ }
+ }
+ st.writeHeaders(header)
+ case receiveDataHeader:
+ t.Logf("receive DATA frame header: size=%v", step.size)
+ st.writeVarint(int64(frameTypeData))
+ st.writeVarint(step.size)
+ st.Flush()
+ case receiveData:
+ t.Logf("receive DATA frame content: size=%v", step.size)
+ for range step.size {
+ st.stream.stream.WriteByte(byte(bytesSent))
+ bytesSent++
+ }
+ st.Flush()
+ case receiveTrailers:
+ st.writeHeaders(http.Header{
+ "x-trailer": []string{"trailer"},
+ })
+ case receiveFrame:
+ st.writeVarint(int64(step.ftype))
+ st.writeVarint(int64(len(step.data)))
+ st.Write(step.data)
+ st.Flush()
+ case receiveEOF:
+ t.Logf("receive EOF on request stream")
+ st.stream.stream.CloseWrite()
+ case wantBody:
+ t.Logf("read %v bytes from response body", step.size)
+ want := make([]byte, step.size)
+ for i := range want {
+ want[i] = byte(bytesReceived)
+ bytesReceived++
+ }
+ got := make([]byte, step.size)
+ n, err := body().Read(got)
+ got = got[:n]
+ if !bytes.Equal(got, want) {
+ t.Errorf("resp.Body.Read:")
+ t.Errorf(" got: {%x}", got)
+ t.Fatalf(" want: {%x}", want)
+ }
+ if err != nil {
+ if step.eof && err == io.EOF {
+ continue
+ }
+ t.Fatalf("resp.Body.Read: unexpected error %v", err)
+ }
+ if step.eof {
+ if n, err := body().Read([]byte{0}); n != 0 || err != io.EOF {
+ t.Fatalf("resp.Body.Read() = %v, %v; want io.EOF", n, err)
+ }
+ }
+ case wantError:
+ if n, err := body().Read([]byte{0}); n != 0 || err == nil || err == io.EOF {
+ t.Fatalf("resp.Body.Read() = %v, %v; want error", n, err)
+ }
+ default:
+ t.Fatalf("unknown test step %T", step)
+ }
+ }
+
+ }
+
+ runSynctestSubtest(t, test.name+"/client", func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ header := http.Header{
+ ":status": []string{"200"},
+ }
+ runTest(t, header, st, func() io.ReadCloser {
+ return rt.response().Body
+ })
+ })
+ }
+}
diff --git a/internal/http3/conn.go b/internal/http3/conn.go
new file mode 100644
index 0000000000..5eb803115e
--- /dev/null
+++ b/internal/http3/conn.go
@@ -0,0 +1,133 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "io"
+ "sync"
+
+ "golang.org/x/net/quic"
+)
+
+type streamHandler interface {
+ handleControlStream(*stream) error
+ handlePushStream(*stream) error
+ handleEncoderStream(*stream) error
+ handleDecoderStream(*stream) error
+ handleRequestStream(*stream) error
+ abort(error)
+}
+
+type genericConn struct {
+ mu sync.Mutex
+
+ // The peer may create exactly one control, encoder, and decoder stream.
+ // streamsCreated is a bitset of streams created so far.
+ // Bits are 1 << streamType.
+ streamsCreated uint8
+}
+
+func (c *genericConn) acceptStreams(qconn *quic.Conn, h streamHandler) {
+ for {
+ // Use context.Background: This blocks until a stream is accepted
+ // or the connection closes.
+ st, err := qconn.AcceptStream(context.Background())
+ if err != nil {
+ return // connection closed
+ }
+ if st.IsReadOnly() {
+ go c.handleUnidirectionalStream(newStream(st), h)
+ } else {
+ go c.handleRequestStream(newStream(st), h)
+ }
+ }
+}
+
+func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) {
+ // Unidirectional stream header: One varint with the stream type.
+ v, err := st.readVarint()
+ if err != nil {
+ h.abort(&connectionError{
+ code: errH3StreamCreationError,
+ message: "error reading unidirectional stream header",
+ })
+ return
+ }
+ stype := streamType(v)
+ if err := c.checkStreamCreation(stype); err != nil {
+ h.abort(err)
+ return
+ }
+ switch stype {
+ case streamTypeControl:
+ err = h.handleControlStream(st)
+ case streamTypePush:
+ err = h.handlePushStream(st)
+ case streamTypeEncoder:
+ err = h.handleEncoderStream(st)
+ case streamTypeDecoder:
+ err = h.handleDecoderStream(st)
+ default:
+ // "Recipients of unknown stream types MUST either abort reading
+ // of the stream or discard incoming data without further processing."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2-7
+ //
+ // We should send the H3_STREAM_CREATION_ERROR error code,
+ // but the quic package currently doesn't allow setting error codes
+ // for STOP_SENDING frames.
+ // TODO: Should CloseRead take an error code?
+ err = nil
+ }
+ if err == io.EOF {
+ err = &connectionError{
+ code: errH3ClosedCriticalStream,
+ message: streamType(stype).String() + " stream closed",
+ }
+ }
+ c.handleStreamError(st, h, err)
+}
+
+func (c *genericConn) handleRequestStream(st *stream, h streamHandler) {
+ c.handleStreamError(st, h, h.handleRequestStream(st))
+}
+
+func (c *genericConn) handleStreamError(st *stream, h streamHandler, err error) {
+ switch err := err.(type) {
+ case *connectionError:
+ h.abort(err)
+ case nil:
+ st.stream.CloseRead()
+ st.stream.CloseWrite()
+ case *streamError:
+ st.stream.CloseRead()
+ st.stream.Reset(uint64(err.code))
+ default:
+ st.stream.CloseRead()
+ st.stream.Reset(uint64(errH3InternalError))
+ }
+}
+
+func (c *genericConn) checkStreamCreation(stype streamType) error {
+ switch stype {
+ case streamTypeControl, streamTypeEncoder, streamTypeDecoder:
+ // The peer may create exactly one control, encoder, and decoder stream.
+ default:
+ return nil
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ bit := uint8(1) << stype
+ if c.streamsCreated&bit != 0 {
+ return &connectionError{
+ code: errH3StreamCreationError,
+ message: "multiple " + stype.String() + " streams created",
+ }
+ }
+ c.streamsCreated |= bit
+ return nil
+}
diff --git a/internal/http3/conn_test.go b/internal/http3/conn_test.go
new file mode 100644
index 0000000000..a9afb1f9e9
--- /dev/null
+++ b/internal/http3/conn_test.go
@@ -0,0 +1,154 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "testing"
+ "testing/synctest"
+)
+
+// Tests which apply to both client and server connections.
+
+func TestConnCreatesControlStream(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ controlStream := tc.wantStream(streamTypeControl)
+ controlStream.wantFrameHeader(
+ "server sends SETTINGS frame on control stream",
+ frameTypeSettings)
+ controlStream.discardFrame()
+ })
+}
+
+func TestConnUnknownUnidirectionalStream(t *testing.T) {
+ // "Recipients of unknown stream types MUST either abort reading of the stream
+ // or discard incoming data without further processing."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2-7
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ st := tc.newStream(0x21) // reserved stream type
+
+ // The endpoint should send a STOP_SENDING for this stream,
+ // but it should not close the connection.
+ synctest.Wait()
+ if _, err := st.Write([]byte("hello")); err == nil {
+ t.Fatalf("write to send-only stream with an unknown type succeeded; want error")
+ }
+ tc.wantNotClosed("after receiving unknown unidirectional stream type")
+ })
+}
+
+func TestConnUnknownSettings(t *testing.T) {
+ // "An implementation MUST ignore any [settings] parameter with
+ // an identifier it does not understand."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-9
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ controlStream := tc.newStream(streamTypeControl)
+ controlStream.writeSettings(0x1f+0x21, 0) // reserved settings type
+ controlStream.Flush()
+ tc.wantNotClosed("after receiving unknown settings")
+ })
+}
+
+func TestConnInvalidSettings(t *testing.T) {
+ // "These reserved settings MUST NOT be sent, and their receipt MUST
+ // be treated as a connection error of type H3_SETTINGS_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-5
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ controlStream := tc.newStream(streamTypeControl)
+ controlStream.writeSettings(0x02, 0) // HTTP/2 SETTINGS_ENABLE_PUSH
+ controlStream.Flush()
+ tc.wantClosed("invalid setting", errH3SettingsError)
+ })
+}
+
+func TestConnDuplicateStream(t *testing.T) {
+ for _, stype := range []streamType{
+ streamTypeControl,
+ streamTypeEncoder,
+ streamTypeDecoder,
+ } {
+ t.Run(stype.String(), func(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ _ = tc.newStream(stype)
+ tc.wantNotClosed("after creating one " + stype.String() + " stream")
+
+ // Opening a second control, encoder, or decoder stream
+ // is a protocol violation.
+ _ = tc.newStream(stype)
+ tc.wantClosed("duplicate stream", errH3StreamCreationError)
+ })
+ })
+ }
+}
+
+func TestConnUnknownFrames(t *testing.T) {
+ for _, stype := range []streamType{
+ streamTypeControl,
+ } {
+ t.Run(stype.String(), func(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ st := tc.newStream(stype)
+
+ if stype == streamTypeControl {
+ // First frame on the control stream must be settings.
+ st.writeVarint(int64(frameTypeSettings))
+ st.writeVarint(0) // size
+ }
+
+ data := "frame content"
+ st.writeVarint(0x1f + 0x21) // reserved frame type
+ st.writeVarint(int64(len(data))) // size
+ st.Write([]byte(data))
+ st.Flush()
+
+ tc.wantNotClosed("after writing unknown frame")
+ })
+ })
+ }
+}
+
+func TestConnInvalidFrames(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ control := tc.newStream(streamTypeControl)
+
+ // SETTINGS frame.
+ control.writeVarint(int64(frameTypeSettings))
+ control.writeVarint(0) // size
+
+ // DATA frame (invalid on the control stream).
+ control.writeVarint(int64(frameTypeData))
+ control.writeVarint(0) // size
+ control.Flush()
+ tc.wantClosed("after writing DATA frame to control stream", errH3FrameUnexpected)
+ })
+}
+
+func TestConnPeerCreatesBadUnidirectionalStream(t *testing.T) {
+ runConnTest(t, func(t testing.TB, tc *testQUICConn) {
+ // Create and close a stream without sending the unidirectional stream header.
+ qs, err := tc.qconn.NewSendOnlyStream(canceledCtx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ st := newTestQUICStream(tc.t, newStream(qs))
+ st.stream.stream.Close()
+
+ tc.wantClosed("after peer creates and closes uni stream", errH3StreamCreationError)
+ })
+}
+
+func runConnTest(t *testing.T, f func(testing.TB, *testQUICConn)) {
+ t.Helper()
+ runSynctestSubtest(t, "client", func(t testing.TB) {
+ tc := newTestClientConn(t)
+ f(t, tc.testQUICConn)
+ })
+ runSynctestSubtest(t, "server", func(t testing.TB) {
+ ts := newTestServer(t)
+ tc := ts.connect()
+ f(t, tc.testQUICConn)
+ })
+}
diff --git a/internal/http3/doc.go b/internal/http3/doc.go
new file mode 100644
index 0000000000..5530113f69
--- /dev/null
+++ b/internal/http3/doc.go
@@ -0,0 +1,10 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package http3 implements the HTTP/3 protocol.
+//
+// This package is a work in progress.
+// It is not ready for production usage.
+// Its API is subject to change without notice.
+package http3
diff --git a/internal/http3/errors.go b/internal/http3/errors.go
new file mode 100644
index 0000000000..db46acfcc8
--- /dev/null
+++ b/internal/http3/errors.go
@@ -0,0 +1,104 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import "fmt"
+
+// http3Error is an HTTP/3 error code.
+type http3Error int
+
+const (
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-8.1
+ errH3NoError = http3Error(0x0100)
+ errH3GeneralProtocolError = http3Error(0x0101)
+ errH3InternalError = http3Error(0x0102)
+ errH3StreamCreationError = http3Error(0x0103)
+ errH3ClosedCriticalStream = http3Error(0x0104)
+ errH3FrameUnexpected = http3Error(0x0105)
+ errH3FrameError = http3Error(0x0106)
+ errH3ExcessiveLoad = http3Error(0x0107)
+ errH3IDError = http3Error(0x0108)
+ errH3SettingsError = http3Error(0x0109)
+ errH3MissingSettings = http3Error(0x010a)
+ errH3RequestRejected = http3Error(0x010b)
+ errH3RequestCancelled = http3Error(0x010c)
+ errH3RequestIncomplete = http3Error(0x010d)
+ errH3MessageError = http3Error(0x010e)
+ errH3ConnectError = http3Error(0x010f)
+ errH3VersionFallback = http3Error(0x0110)
+
+ // https://www.rfc-editor.org/rfc/rfc9204.html#section-8.3
+ errQPACKDecompressionFailed = http3Error(0x0200)
+ errQPACKEncoderStreamError = http3Error(0x0201)
+ errQPACKDecoderStreamError = http3Error(0x0202)
+)
+
+func (e http3Error) Error() string {
+ switch e {
+ case errH3NoError:
+ return "H3_NO_ERROR"
+ case errH3GeneralProtocolError:
+ return "H3_GENERAL_PROTOCOL_ERROR"
+ case errH3InternalError:
+ return "H3_INTERNAL_ERROR"
+ case errH3StreamCreationError:
+ return "H3_STREAM_CREATION_ERROR"
+ case errH3ClosedCriticalStream:
+ return "H3_CLOSED_CRITICAL_STREAM"
+ case errH3FrameUnexpected:
+ return "H3_FRAME_UNEXPECTED"
+ case errH3FrameError:
+ return "H3_FRAME_ERROR"
+ case errH3ExcessiveLoad:
+ return "H3_EXCESSIVE_LOAD"
+ case errH3IDError:
+ return "H3_ID_ERROR"
+ case errH3SettingsError:
+ return "H3_SETTINGS_ERROR"
+ case errH3MissingSettings:
+ return "H3_MISSING_SETTINGS"
+ case errH3RequestRejected:
+ return "H3_REQUEST_REJECTED"
+ case errH3RequestCancelled:
+ return "H3_REQUEST_CANCELLED"
+ case errH3RequestIncomplete:
+ return "H3_REQUEST_INCOMPLETE"
+ case errH3MessageError:
+ return "H3_MESSAGE_ERROR"
+ case errH3ConnectError:
+ return "H3_CONNECT_ERROR"
+ case errH3VersionFallback:
+ return "H3_VERSION_FALLBACK"
+ case errQPACKDecompressionFailed:
+ return "QPACK_DECOMPRESSION_FAILED"
+ case errQPACKEncoderStreamError:
+ return "QPACK_ENCODER_STREAM_ERROR"
+ case errQPACKDecoderStreamError:
+ return "QPACK_DECODER_STREAM_ERROR"
+ }
+ return fmt.Sprintf("H3_ERROR_%v", int(e))
+}
+
+// A streamError is an error which terminates a stream, but not the connection.
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-8-1
+type streamError struct {
+ code http3Error
+ message string
+}
+
+func (e *streamError) Error() string { return e.message }
+func (e *streamError) Unwrap() error { return e.code }
+
+// A connectionError is an error which results in the entire connection closing.
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-8-2
+type connectionError struct {
+ code http3Error
+ message string
+}
+
+func (e *connectionError) Error() string { return e.message }
+func (e *connectionError) Unwrap() error { return e.code }
diff --git a/internal/http3/files_test.go b/internal/http3/files_test.go
new file mode 100644
index 0000000000..9c97a6ced4
--- /dev/null
+++ b/internal/http3/files_test.go
@@ -0,0 +1,56 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "os"
+ "strings"
+ "testing"
+)
+
+// TestFiles checks that every file in this package has a build constraint on Go 1.24.
+//
+// Package tests rely on testing/synctest, added as an experiment in Go 1.24.
+// When moving internal/http3 to an importable location, we can decide whether
+// to relax the constraint for non-test files.
+//
+// Drop this test when the x/net go.mod depends on 1.24 or newer.
+func TestFiles(t *testing.T) {
+ f, err := os.Open(".")
+ if err != nil {
+ t.Fatal(err)
+ }
+ names, err := f.Readdirnames(-1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, name := range names {
+ if !strings.HasSuffix(name, ".go") {
+ continue
+ }
+ b, err := os.ReadFile(name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Check for copyright header while we're in here.
+ if !bytes.Contains(b, []byte("The Go Authors.")) {
+ t.Errorf("%v: missing copyright", name)
+ }
+ // doc.go doesn't need a build constraint.
+ if name == "doc.go" {
+ continue
+ }
+ if !bytes.Contains(b, []byte("//go:build go1.24")) {
+ t.Errorf("%v: missing constraint on go1.24", name)
+ }
+ if bytes.Contains(b, []byte(`"testing/synctest"`)) &&
+ !bytes.Contains(b, []byte("//go:build go1.24 && goexperiment.synctest")) {
+ t.Errorf("%v: missing constraint on go1.24 && goexperiment.synctest", name)
+ }
+ }
+}
diff --git a/internal/http3/http3.go b/internal/http3/http3.go
new file mode 100644
index 0000000000..1f60670564
--- /dev/null
+++ b/internal/http3/http3.go
@@ -0,0 +1,86 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import "fmt"
+
+// Stream types.
+//
+// For unidirectional streams, the value is the stream type sent over the wire.
+//
+// For bidirectional streams (which are always request streams),
+// the value is arbitrary and never sent on the wire.
+type streamType int64
+
+const (
+ // Bidirectional request stream.
+ // All bidirectional streams are request streams.
+ // This stream type is never sent over the wire.
+ //
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1
+ streamTypeRequest = streamType(-1)
+
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2
+ streamTypeControl = streamType(0x00)
+ streamTypePush = streamType(0x01)
+
+ // https://www.rfc-editor.org/rfc/rfc9204.html#section-4.2
+ streamTypeEncoder = streamType(0x02)
+ streamTypeDecoder = streamType(0x03)
+)
+
+func (stype streamType) String() string {
+ switch stype {
+ case streamTypeRequest:
+ return "request"
+ case streamTypeControl:
+ return "control"
+ case streamTypePush:
+ return "push"
+ case streamTypeEncoder:
+ return "encoder"
+ case streamTypeDecoder:
+ return "decoder"
+ default:
+ return "unknown"
+ }
+}
+
+// Frame types.
+type frameType int64
+
+const (
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2
+ frameTypeData = frameType(0x00)
+ frameTypeHeaders = frameType(0x01)
+ frameTypeCancelPush = frameType(0x03)
+ frameTypeSettings = frameType(0x04)
+ frameTypePushPromise = frameType(0x05)
+ frameTypeGoaway = frameType(0x07)
+ frameTypeMaxPushID = frameType(0x0d)
+)
+
+func (ftype frameType) String() string {
+ switch ftype {
+ case frameTypeData:
+ return "DATA"
+ case frameTypeHeaders:
+ return "HEADERS"
+ case frameTypeCancelPush:
+ return "CANCEL_PUSH"
+ case frameTypeSettings:
+ return "SETTINGS"
+ case frameTypePushPromise:
+ return "PUSH_PROMISE"
+ case frameTypeGoaway:
+ return "GOAWAY"
+ case frameTypeMaxPushID:
+ return "MAX_PUSH_ID"
+ default:
+ return fmt.Sprintf("UNKNOWN_%d", int64(ftype))
+ }
+}
diff --git a/internal/http3/http3_test.go b/internal/http3/http3_test.go
new file mode 100644
index 0000000000..f490ad3f03
--- /dev/null
+++ b/internal/http3/http3_test.go
@@ -0,0 +1,82 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "encoding/hex"
+ "os"
+ "slices"
+ "strings"
+ "testing"
+ "testing/synctest"
+)
+
+func init() {
+ // testing/synctest requires asynctimerchan=0 (the default as of Go 1.23),
+ // but the x/net go.mod is currently selecting go1.18.
+ //
+ // Set asynctimerchan=0 explicitly.
+ //
+ // TODO: Remove this when the x/net go.mod Go version is >= go1.23.
+ os.Setenv("GODEBUG", os.Getenv("GODEBUG")+",asynctimerchan=0")
+}
+
+// runSynctest runs f in a synctest.Run bubble.
+// It arranges for t.Cleanup functions to run within the bubble.
+func runSynctest(t *testing.T, f func(t testing.TB)) {
+ synctest.Run(func() {
+ ct := &cleanupT{T: t}
+ defer ct.done()
+ f(ct)
+ })
+}
+
+// runSynctestSubtest runs f in a subtest in a synctest.Run bubble.
+func runSynctestSubtest(t *testing.T, name string, f func(t testing.TB)) {
+ t.Run(name, func(t *testing.T) {
+ runSynctest(t, f)
+ })
+}
+
+// cleanupT wraps a testing.T and adds its own Cleanup method.
+// Used to execute cleanup functions within a synctest bubble.
+type cleanupT struct {
+ *testing.T
+ cleanups []func()
+}
+
+// Cleanup replaces T.Cleanup.
+func (t *cleanupT) Cleanup(f func()) {
+ t.cleanups = append(t.cleanups, f)
+}
+
+func (t *cleanupT) done() {
+ for _, f := range slices.Backward(t.cleanups) {
+ f()
+ }
+}
+
+func unhex(s string) []byte {
+ b, err := hex.DecodeString(strings.Map(func(c rune) rune {
+ switch c {
+ case ' ', '\t', '\n':
+ return -1 // ignore
+ }
+ return c
+ }, s))
+ if err != nil {
+ panic(err)
+ }
+ return b
+}
+
+// testReader implements io.Reader.
+type testReader struct {
+ readFunc func([]byte) (int, error)
+}
+
+func (r testReader) Read(p []byte) (n int, err error) { return r.readFunc(p) }
diff --git a/internal/http3/qpack.go b/internal/http3/qpack.go
new file mode 100644
index 0000000000..66f4e29762
--- /dev/null
+++ b/internal/http3/qpack.go
@@ -0,0 +1,334 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "io"
+
+ "golang.org/x/net/http2/hpack"
+)
+
+// QPACK (RFC 9204) header compression wire encoding.
+// https://www.rfc-editor.org/rfc/rfc9204.html
+
+// tableType is the static or dynamic table.
+//
+// The T bit in QPACK instructions indicates whether a table index refers to
+// the dynamic (T=0) or static (T=1) table. tableTypeForTBit and tableType.tbit
+// convert a T bit from the wire encoding to/from a tableType.
+type tableType byte
+
+const (
+ dynamicTable = 0x00 // T=0, dynamic table
+ staticTable = 0xff // T=1, static table
+)
+
+// tableTypeForTbit returns the table type corresponding to a T bit value.
+// The input parameter contains a byte masked to contain only the T bit.
+func tableTypeForTbit(bit byte) tableType {
+ if bit == 0 {
+ return dynamicTable
+ }
+ return staticTable
+}
+
+// tbit produces the T bit corresponding to the table type.
+// The input parameter contains a byte with the T bit set to 1,
+// and the return is either the input or 0 depending on the table type.
+func (t tableType) tbit(bit byte) byte {
+ return bit & byte(t)
+}
+
+// indexType indicates a literal's indexing status.
+//
+// The N bit in QPACK instructions indicates whether a literal is "never-indexed".
+// A never-indexed literal (N=1) must not be encoded as an indexed literal if it
+// is forwarded on another connection.
+//
+// (See https://www.rfc-editor.org/rfc/rfc9204.html#section-7.1 for details on the
+// security reasons for never-indexed literals.)
+type indexType byte
+
+const (
+ mayIndex = 0x00 // N=0, not a never-indexed literal
+ neverIndex = 0xff // N=1, never-indexed literal
+)
+
+// indexTypeForNBit returns the index type corresponding to an N bit value.
+// The input parameter contains a byte masked to contain only the N bit.
+func indexTypeForNBit(bit byte) indexType {
+ if bit == 0 {
+ return mayIndex
+ }
+ return neverIndex
+}
+
+// nbit produces the N bit corresponding to the index type.
+// The input parameter contains a byte with the N bit set to 1,
+// and the return is either the input or 0 depending on the index type.
+func (t indexType) nbit(bit byte) byte {
+ return bit & byte(t)
+}
+
+// Indexed Field Line:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 1 | T | Index (6+) |
+// +---+---+-----------------------+
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.2
+
+func appendIndexedFieldLine(b []byte, ttype tableType, index int) []byte {
+ const tbit = 0b_01000000
+ return appendPrefixedInt(b, 0b_1000_0000|ttype.tbit(tbit), 6, int64(index))
+}
+
+func (st *stream) decodeIndexedFieldLine(b byte) (itype indexType, name, value string, err error) {
+ index, err := st.readPrefixedIntWithByte(b, 6)
+ if err != nil {
+ return 0, "", "", err
+ }
+ const tbit = 0b_0100_0000
+ if tableTypeForTbit(b&tbit) == staticTable {
+ ent, err := staticTableEntry(index)
+ if err != nil {
+ return 0, "", "", err
+ }
+ return mayIndex, ent.name, ent.value, nil
+ } else {
+ return 0, "", "", errors.New("dynamic table is not supported yet")
+ }
+}
+
+// Literal Field Line With Name Reference:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 0 | 1 | N | T |Name Index (4+)|
+// +---+---+---+---+---------------+
+// | H | Value Length (7+) |
+// +---+---------------------------+
+// | Value String (Length bytes) |
+// +-------------------------------+
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.4
+
+func appendLiteralFieldLineWithNameReference(b []byte, ttype tableType, itype indexType, nameIndex int, value string) []byte {
+ const tbit = 0b_0001_0000
+ const nbit = 0b_0010_0000
+ b = appendPrefixedInt(b, 0b_0100_0000|itype.nbit(nbit)|ttype.tbit(tbit), 4, int64(nameIndex))
+ b = appendPrefixedString(b, 0, 7, value)
+ return b
+}
+
+func (st *stream) decodeLiteralFieldLineWithNameReference(b byte) (itype indexType, name, value string, err error) {
+ nameIndex, err := st.readPrefixedIntWithByte(b, 4)
+ if err != nil {
+ return 0, "", "", err
+ }
+
+ const tbit = 0b_0001_0000
+ if tableTypeForTbit(b&tbit) == staticTable {
+ ent, err := staticTableEntry(nameIndex)
+ if err != nil {
+ return 0, "", "", err
+ }
+ name = ent.name
+ } else {
+ return 0, "", "", errors.New("dynamic table is not supported yet")
+ }
+
+ _, value, err = st.readPrefixedString(7)
+ if err != nil {
+ return 0, "", "", err
+ }
+
+ const nbit = 0b_0010_0000
+ itype = indexTypeForNBit(b & nbit)
+
+ return itype, name, value, nil
+}
+
+// Literal Field Line with Literal Name:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 0 | 0 | 1 | N | H |NameLen(3+)|
+// +---+---+---+---+---+-----------+
+// | Name String (Length bytes) |
+// +---+---------------------------+
+// | H | Value Length (7+) |
+// +---+---------------------------+
+// | Value String (Length bytes) |
+// +-------------------------------+
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.6
+
+func appendLiteralFieldLineWithLiteralName(b []byte, itype indexType, name, value string) []byte {
+ const nbit = 0b_0001_0000
+ b = appendPrefixedString(b, 0b_0010_0000|itype.nbit(nbit), 3, name)
+ b = appendPrefixedString(b, 0, 7, value)
+ return b
+}
+
+func (st *stream) decodeLiteralFieldLineWithLiteralName(b byte) (itype indexType, name, value string, err error) {
+ name, err = st.readPrefixedStringWithByte(b, 3)
+ if err != nil {
+ return 0, "", "", err
+ }
+ _, value, err = st.readPrefixedString(7)
+ if err != nil {
+ return 0, "", "", err
+ }
+ const nbit = 0b_0001_0000
+ itype = indexTypeForNBit(b & nbit)
+ return itype, name, value, nil
+}
+
+// Prefixed-integer encoding from RFC 7541, section 5.1
+//
+// Prefixed integers consist of some number of bits of data,
+// N bits of encoded integer, and 0 or more additional bytes of
+// encoded integer.
+//
+// The RFCs represent this as, for example:
+//
+// 0 1 2 3 4 5 6 7
+// +---+---+---+---+---+---+---+---+
+// | 0 | 0 | 1 | Capacity (5+) |
+// +---+---+---+-------------------+
+//
+// "Capacity" is an integer with a 5-bit prefix.
+//
+// In the following functions, a "prefixLen" parameter is the number
+// of integer bits in the first byte (5 in the above example), and
+// a "firstByte" parameter is a byte containing the first byte of
+// the encoded value (0x001x_xxxx in the above example).
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.1
+// https://www.rfc-editor.org/rfc/rfc7541#section-5.1
+
+// readPrefixedInt reads an RFC 7541 prefixed integer from st.
+func (st *stream) readPrefixedInt(prefixLen uint8) (firstByte byte, v int64, err error) {
+ firstByte, err = st.ReadByte()
+ if err != nil {
+ return 0, 0, errQPACKDecompressionFailed
+ }
+ v, err = st.readPrefixedIntWithByte(firstByte, prefixLen)
+ return firstByte, v, err
+}
+
+// readPrefixedIntWithByte reads an RFC 7541 prefixed integer from st.
+// The first byte has already been read from the stream.
+func (st *stream) readPrefixedIntWithByte(firstByte byte, prefixLen uint8) (v int64, err error) {
+ prefixMask := (byte(1) << prefixLen) - 1
+ v = int64(firstByte & prefixMask)
+ if v != int64(prefixMask) {
+ return v, nil
+ }
+ m := 0
+ for {
+ b, err := st.ReadByte()
+ if err != nil {
+ return 0, errQPACKDecompressionFailed
+ }
+ v += int64(b&127) << m
+ m += 7
+ if b&128 == 0 {
+ break
+ }
+ }
+ return v, err
+}
+
+// appendPrefixedInt appends an RFC 7541 prefixed integer to b.
+//
+// The firstByte parameter includes the non-integer bits of the first byte.
+// The other bits must be zero.
+func appendPrefixedInt(b []byte, firstByte byte, prefixLen uint8, i int64) []byte {
+ u := uint64(i)
+ prefixMask := (uint64(1) << prefixLen) - 1
+ if u < prefixMask {
+ return append(b, firstByte|byte(u))
+ }
+ b = append(b, firstByte|byte(prefixMask))
+ u -= prefixMask
+ for u >= 128 {
+ b = append(b, 0x80|byte(u&0x7f))
+ u >>= 7
+ }
+ return append(b, byte(u))
+}
+
+// String literal encoding from RFC 7541, section 5.2
+//
+// String literals consist of a single bit flag indicating
+// whether the string is Huffman-encoded, a prefixed integer (see above),
+// and the string.
+//
+// https://www.rfc-editor.org/rfc/rfc9204.html#section-4.1.2
+// https://www.rfc-editor.org/rfc/rfc7541#section-5.2
+
+// readPrefixedString reads an RFC 7541 string from st.
+func (st *stream) readPrefixedString(prefixLen uint8) (firstByte byte, s string, err error) {
+ firstByte, err = st.ReadByte()
+ if err != nil {
+ return 0, "", errQPACKDecompressionFailed
+ }
+ s, err = st.readPrefixedStringWithByte(firstByte, prefixLen)
+ return firstByte, s, err
+}
+
+// readPrefixedStringWithByte reads an RFC 7541 string from st.
+// The first byte has already been read from the stream.
+func (st *stream) readPrefixedStringWithByte(firstByte byte, prefixLen uint8) (s string, err error) {
+ size, err := st.readPrefixedIntWithByte(firstByte, prefixLen)
+ if err != nil {
+ return "", errQPACKDecompressionFailed
+ }
+
+ hbit := byte(1) << prefixLen
+ isHuffman := firstByte&hbit != 0
+
+ // TODO: Avoid allocating here.
+ data := make([]byte, size)
+ if _, err := io.ReadFull(st, data); err != nil {
+ return "", errQPACKDecompressionFailed
+ }
+ if isHuffman {
+ // TODO: Move Huffman functions into a new package that hpack (HTTP/2)
+ // and this package can both import. Most of the hpack package isn't
+ // relevant to HTTP/3.
+ s, err := hpack.HuffmanDecodeToString(data)
+ if err != nil {
+ return "", errQPACKDecompressionFailed
+ }
+ return s, nil
+ }
+ return string(data), nil
+}
+
+// appendPrefixedString appends an RFC 7541 string to b,
+// applying Huffman encoding and setting the H bit (indicating Huffman encoding)
+// when appropriate.
+//
+// The firstByte parameter includes the non-integer bits of the first byte.
+// The other bits must be zero.
+func appendPrefixedString(b []byte, firstByte byte, prefixLen uint8, s string) []byte {
+ huffmanLen := hpack.HuffmanEncodeLength(s)
+ if huffmanLen < uint64(len(s)) {
+ hbit := byte(1) << prefixLen
+ b = appendPrefixedInt(b, firstByte|hbit, prefixLen, int64(huffmanLen))
+ b = hpack.AppendHuffmanString(b, s)
+ } else {
+ b = appendPrefixedInt(b, firstByte, prefixLen, int64(len(s)))
+ b = append(b, s...)
+ }
+ return b
+}
diff --git a/internal/http3/qpack_decode.go b/internal/http3/qpack_decode.go
new file mode 100644
index 0000000000..018867afb1
--- /dev/null
+++ b/internal/http3/qpack_decode.go
@@ -0,0 +1,83 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "math/bits"
+)
+
+type qpackDecoder struct {
+ // The decoder has no state for now,
+ // but that'll change once we add dynamic table support.
+ //
+ // TODO: dynamic table support.
+}
+
+func (qd *qpackDecoder) decode(st *stream, f func(itype indexType, name, value string) error) error {
+ // Encoded Field Section prefix.
+
+ // We set SETTINGS_QPACK_MAX_TABLE_CAPACITY to 0,
+ // so the Required Insert Count must be 0.
+ _, requiredInsertCount, err := st.readPrefixedInt(8)
+ if err != nil {
+ return err
+ }
+ if requiredInsertCount != 0 {
+ return errQPACKDecompressionFailed
+ }
+
+ // Delta Base. We don't use the dynamic table yet, so this may be ignored.
+ _, _, err = st.readPrefixedInt(7)
+ if err != nil {
+ return err
+ }
+
+ sawNonPseudo := false
+ for st.lim > 0 {
+ firstByte, err := st.ReadByte()
+ if err != nil {
+ return err
+ }
+ var name, value string
+ var itype indexType
+ switch bits.LeadingZeros8(firstByte) {
+ case 0:
+ // Indexed Field Line
+ itype, name, value, err = st.decodeIndexedFieldLine(firstByte)
+ case 1:
+ // Literal Field Line With Name Reference
+ itype, name, value, err = st.decodeLiteralFieldLineWithNameReference(firstByte)
+ case 2:
+ // Literal Field Line with Literal Name
+ itype, name, value, err = st.decodeLiteralFieldLineWithLiteralName(firstByte)
+ case 3:
+ // Indexed Field Line With Post-Base Index
+ err = errors.New("dynamic table is not supported yet")
+ case 4:
+ // Indexed Field Line With Post-Base Name Reference
+ err = errors.New("dynamic table is not supported yet")
+ }
+ if err != nil {
+ return err
+ }
+ if len(name) == 0 {
+ return errH3MessageError
+ }
+ if name[0] == ':' {
+ if sawNonPseudo {
+ return errH3MessageError
+ }
+ } else {
+ sawNonPseudo = true
+ }
+ if err := f(itype, name, value); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/internal/http3/qpack_decode_test.go b/internal/http3/qpack_decode_test.go
new file mode 100644
index 0000000000..1b779aa782
--- /dev/null
+++ b/internal/http3/qpack_decode_test.go
@@ -0,0 +1,196 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
+func TestQPACKDecode(t *testing.T) {
+ type header struct {
+ itype indexType
+ name, value string
+ }
+ // Many test cases here taken from Google QUICHE,
+ // quiche/quic/core/qpack/qpack_encoder_test.cc.
+ for _, test := range []struct {
+ name string
+ enc []byte
+ want []header
+ }{{
+ name: "empty",
+ enc: unhex("0000"),
+ want: []header{},
+ }, {
+ name: "literal entry empty value",
+ enc: unhex("000023666f6f00"),
+ want: []header{
+ {mayIndex, "foo", ""},
+ },
+ }, {
+ name: "simple literal entry",
+ enc: unhex("000023666f6f03626172"),
+ want: []header{
+ {mayIndex, "foo", "bar"},
+ },
+ }, {
+ name: "multiple literal entries",
+ enc: unhex("0000" + // prefix
+ // foo: bar
+ "23666f6f03626172" +
+ // 7 octet long header name, the smallest number
+ // that does not fit on a 3-bit prefix.
+ "2700666f6f62616172" +
+ // 127 octet long header value, the smallest number
+ // that does not fit on a 7-bit prefix.
+ "7f00616161616161616161616161616161616161616161616161616161616161616161" +
+ "6161616161616161616161616161616161616161616161616161616161616161616161" +
+ "6161616161616161616161616161616161616161616161616161616161616161616161" +
+ "616161616161616161616161616161616161616161616161",
+ ),
+ want: []header{
+ {mayIndex, "foo", "bar"},
+ {mayIndex, "foobaar", strings.Repeat("a", 127)},
+ },
+ }, {
+ name: "line feed in value",
+ enc: unhex("000023666f6f0462610a72"),
+ want: []header{
+ {mayIndex, "foo", "ba\nr"},
+ },
+ }, {
+ name: "huffman simple",
+ enc: unhex("00002f0125a849e95ba97d7f8925a849e95bb8e8b4bf"),
+ want: []header{
+ {mayIndex, "custom-key", "custom-value"},
+ },
+ }, {
+ name: "alternating huffman nonhuffman",
+ enc: unhex("0000" + // Prefix.
+ "2f0125a849e95ba97d7f" + // Huffman-encoded name.
+ "8925a849e95bb8e8b4bf" + // Huffman-encoded value.
+ "2703637573746f6d2d6b6579" + // Non-Huffman encoded name.
+ "0c637573746f6d2d76616c7565" + // Non-Huffman encoded value.
+ "2f0125a849e95ba97d7f" + // Huffman-encoded name.
+ "0c637573746f6d2d76616c7565" + // Non-Huffman encoded value.
+ "2703637573746f6d2d6b6579" + // Non-Huffman encoded name.
+ "8925a849e95bb8e8b4bf", // Huffman-encoded value.
+ ),
+ want: []header{
+ {mayIndex, "custom-key", "custom-value"},
+ {mayIndex, "custom-key", "custom-value"},
+ {mayIndex, "custom-key", "custom-value"},
+ {mayIndex, "custom-key", "custom-value"},
+ },
+ }, {
+ name: "static table",
+ enc: unhex("0000d1d45f00055452414345dfcc5f108621e9aec2a11f5c8294e75f1000"),
+ want: []header{
+ {mayIndex, ":method", "GET"},
+ {mayIndex, ":method", "POST"},
+ {mayIndex, ":method", "TRACE"},
+ {mayIndex, "accept-encoding", "gzip, deflate, br"},
+ {mayIndex, "location", ""},
+ {mayIndex, "accept-encoding", "compress"},
+ {mayIndex, "location", "foo"},
+ {mayIndex, "accept-encoding", ""},
+ },
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ st1, st2 := newStreamPair(t)
+ st1.Write(test.enc)
+ st1.Flush()
+
+ st2.lim = int64(len(test.enc))
+
+ var dec qpackDecoder
+ got := []header{}
+ err := dec.decode(st2, func(itype indexType, name, value string) error {
+ got = append(got, header{itype, name, value})
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("decode: %v", err)
+ }
+ if !reflect.DeepEqual(got, test.want) {
+ t.Errorf("encoded: %x", test.enc)
+ t.Errorf("got headers:")
+ for _, h := range got {
+ t.Errorf(" %v: %q", h.name, h.value)
+ }
+ t.Errorf("want headers:")
+ for _, h := range test.want {
+ t.Errorf(" %v: %q", h.name, h.value)
+ }
+ }
+ })
+ }
+}
+
+func TestQPACKDecodeErrors(t *testing.T) {
+ // Many test cases here taken from Google QUICHE,
+ // quiche/quic/core/qpack/qpack_encoder_test.cc.
+ for _, test := range []struct {
+ name string
+ enc []byte
+ }{{
+ name: "literal entry empty name",
+ enc: unhex("00002003666f6f"),
+ }, {
+ name: "literal entry empty name and value",
+ enc: unhex("00002000"),
+ }, {
+ name: "name length too large for varint",
+ enc: unhex("000027ffffffffffffffffffff"),
+ }, {
+ name: "string literal too long",
+ enc: unhex("000027ffff7f"),
+ }, {
+ name: "value length too large for varint",
+ enc: unhex("000023666f6f7fffffffffffffffffffff"),
+ }, {
+ name: "value length too long",
+ enc: unhex("000023666f6f7fffff7f"),
+ }, {
+ name: "incomplete header block",
+ enc: unhex("00002366"),
+ }, {
+ name: "huffman name does not have eos prefix",
+ enc: unhex("00002f0125a849e95ba97d7e8925a849e95bb8e8b4bf"),
+ }, {
+ name: "huffman value does not have eos prefix",
+ enc: unhex("00002f0125a849e95ba97d7f8925a849e95bb8e8b4be"),
+ }, {
+ name: "huffman name eos prefix too long",
+ enc: unhex("00002f0225a849e95ba97d7fff8925a849e95bb8e8b4bf"),
+ }, {
+ name: "huffman value eos prefix too long",
+ enc: unhex("00002f0125a849e95ba97d7f8a25a849e95bb8e8b4bfff"),
+ }, {
+ name: "too high static table index",
+ enc: unhex("0000ff23ff24"),
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ st1, st2 := newStreamPair(t)
+ st1.Write(test.enc)
+ st1.Flush()
+
+ st2.lim = int64(len(test.enc))
+
+ var dec qpackDecoder
+ err := dec.decode(st2, func(itype indexType, name, value string) error {
+ return nil
+ })
+ if err == nil {
+ t.Errorf("encoded: %x", test.enc)
+ t.Fatalf("decode succeeded; want error")
+ }
+ })
+ }
+}
diff --git a/internal/http3/qpack_encode.go b/internal/http3/qpack_encode.go
new file mode 100644
index 0000000000..0f35e0c54f
--- /dev/null
+++ b/internal/http3/qpack_encode.go
@@ -0,0 +1,47 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+type qpackEncoder struct {
+ // The encoder has no state for now,
+ // but that'll change once we add dynamic table support.
+ //
+ // TODO: dynamic table support.
+}
+
+func (qe *qpackEncoder) init() {
+ staticTableOnce.Do(initStaticTableMaps)
+}
+
+// encode encodes a list of headers into a QPACK encoded field section.
+//
+// The headers func must produce the same headers on repeated calls,
+// although the order may vary.
+func (qe *qpackEncoder) encode(headers func(func(itype indexType, name, value string))) []byte {
+ // Encoded Field Section prefix.
+ //
+ // We don't yet use the dynamic table, so both values here are zero.
+ var b []byte
+ b = appendPrefixedInt(b, 0, 8, 0) // Required Insert Count
+ b = appendPrefixedInt(b, 0, 7, 0) // Delta Base
+
+ headers(func(itype indexType, name, value string) {
+ if itype == mayIndex {
+ if i, ok := staticTableByNameValue[tableEntry{name, value}]; ok {
+ b = appendIndexedFieldLine(b, staticTable, i)
+ return
+ }
+ }
+ if i, ok := staticTableByName[name]; ok {
+ b = appendLiteralFieldLineWithNameReference(b, staticTable, itype, i, value)
+ } else {
+ b = appendLiteralFieldLineWithLiteralName(b, itype, name, value)
+ }
+ })
+
+ return b
+}
diff --git a/internal/http3/qpack_encode_test.go b/internal/http3/qpack_encode_test.go
new file mode 100644
index 0000000000..f426d773a6
--- /dev/null
+++ b/internal/http3/qpack_encode_test.go
@@ -0,0 +1,126 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+)
+
+func TestQPACKEncode(t *testing.T) {
+ type header struct {
+ itype indexType
+ name, value string
+ }
+ // Many test cases here taken from Google QUICHE,
+ // quiche/quic/core/qpack/qpack_encoder_test.cc.
+ for _, test := range []struct {
+ name string
+ headers []header
+ want []byte
+ }{{
+ name: "empty",
+ headers: []header{},
+ want: unhex("0000"),
+ }, {
+ name: "empty name",
+ headers: []header{
+ {mayIndex, "", "foo"},
+ },
+ want: unhex("0000208294e7"),
+ }, {
+ name: "empty value",
+ headers: []header{
+ {mayIndex, "foo", ""},
+ },
+ want: unhex("00002a94e700"),
+ }, {
+ name: "empty name and value",
+ headers: []header{
+ {mayIndex, "", ""},
+ },
+ want: unhex("00002000"),
+ }, {
+ name: "simple",
+ headers: []header{
+ {mayIndex, "foo", "bar"},
+ },
+ want: unhex("00002a94e703626172"),
+ }, {
+ name: "multiple",
+ headers: []header{
+ {mayIndex, "foo", "bar"},
+ {mayIndex, "ZZZZZZZ", strings.Repeat("Z", 127)},
+ },
+ want: unhex("0000" + // prefix
+ // foo: bar
+ "2a94e703626172" +
+ // 7 octet long header name, the smallest number
+ // that does not fit on a 3-bit prefix.
+ "27005a5a5a5a5a5a5a" +
+ // 127 octet long header value, the smallest
+ // number that does not fit on a 7-bit prefix.
+ "7f005a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" +
+ "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" +
+ "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a" +
+ "5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a"),
+ }, {
+ name: "static table 1",
+ headers: []header{
+ {mayIndex, ":method", "GET"},
+ {mayIndex, "accept-encoding", "gzip, deflate, br"},
+ {mayIndex, "location", ""},
+ },
+ want: unhex("0000d1dfcc"),
+ }, {
+ name: "static table 2",
+ headers: []header{
+ {mayIndex, ":method", "POST"},
+ {mayIndex, "accept-encoding", "compress"},
+ {mayIndex, "location", "foo"},
+ },
+ want: unhex("0000d45f108621e9aec2a11f5c8294e7"),
+ }, {
+ name: "static table 3",
+ headers: []header{
+ {mayIndex, ":method", "TRACE"},
+ {mayIndex, "accept-encoding", ""},
+ },
+ want: unhex("00005f000554524143455f1000"),
+ }, {
+ name: "never indexed literal field line with name reference",
+ headers: []header{
+ {neverIndex, ":method", ""},
+ },
+ want: unhex("00007f0000"),
+ }, {
+ name: "never indexed literal field line with literal name",
+ headers: []header{
+ {neverIndex, "a", "b"},
+ },
+ want: unhex("000031610162"),
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ var enc qpackEncoder
+ enc.init()
+
+ got := enc.encode(func(f func(itype indexType, name, value string)) {
+ for _, h := range test.headers {
+ f(h.itype, h.name, h.value)
+ }
+ })
+ if !bytes.Equal(got, test.want) {
+ for _, h := range test.headers {
+ t.Logf("header %v: %q", h.name, h.value)
+ }
+ t.Errorf("got: %x", got)
+ t.Errorf("want: %x", test.want)
+ }
+ })
+ }
+}
diff --git a/internal/http3/qpack_static.go b/internal/http3/qpack_static.go
new file mode 100644
index 0000000000..cb0884eb7b
--- /dev/null
+++ b/internal/http3/qpack_static.go
@@ -0,0 +1,144 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import "sync"
+
+type tableEntry struct {
+ name string
+ value string
+}
+
+// staticTableEntry returns the static table entry with the given index.
+func staticTableEntry(index int64) (tableEntry, error) {
+ if index >= int64(len(staticTableEntries)) {
+ return tableEntry{}, errQPACKDecompressionFailed
+ }
+ return staticTableEntries[index], nil
+}
+
+func initStaticTableMaps() {
+ staticTableByName = make(map[string]int)
+ staticTableByNameValue = make(map[tableEntry]int)
+ for i, ent := range staticTableEntries {
+ if _, ok := staticTableByName[ent.name]; !ok {
+ staticTableByName[ent.name] = i
+ }
+ staticTableByNameValue[ent] = i
+ }
+}
+
+var (
+ staticTableOnce sync.Once
+ staticTableByName map[string]int
+ staticTableByNameValue map[tableEntry]int
+)
+
+// https://www.rfc-editor.org/rfc/rfc9204.html#appendix-A
+//
+// Note that this is different from the HTTP/2 static table.
+var staticTableEntries = [...]tableEntry{
+ 0: {":authority", ""},
+ 1: {":path", "/"},
+ 2: {"age", "0"},
+ 3: {"content-disposition", ""},
+ 4: {"content-length", "0"},
+ 5: {"cookie", ""},
+ 6: {"date", ""},
+ 7: {"etag", ""},
+ 8: {"if-modified-since", ""},
+ 9: {"if-none-match", ""},
+ 10: {"last-modified", ""},
+ 11: {"link", ""},
+ 12: {"location", ""},
+ 13: {"referer", ""},
+ 14: {"set-cookie", ""},
+ 15: {":method", "CONNECT"},
+ 16: {":method", "DELETE"},
+ 17: {":method", "GET"},
+ 18: {":method", "HEAD"},
+ 19: {":method", "OPTIONS"},
+ 20: {":method", "POST"},
+ 21: {":method", "PUT"},
+ 22: {":scheme", "http"},
+ 23: {":scheme", "https"},
+ 24: {":status", "103"},
+ 25: {":status", "200"},
+ 26: {":status", "304"},
+ 27: {":status", "404"},
+ 28: {":status", "503"},
+ 29: {"accept", "*/*"},
+ 30: {"accept", "application/dns-message"},
+ 31: {"accept-encoding", "gzip, deflate, br"},
+ 32: {"accept-ranges", "bytes"},
+ 33: {"access-control-allow-headers", "cache-control"},
+ 34: {"access-control-allow-headers", "content-type"},
+ 35: {"access-control-allow-origin", "*"},
+ 36: {"cache-control", "max-age=0"},
+ 37: {"cache-control", "max-age=2592000"},
+ 38: {"cache-control", "max-age=604800"},
+ 39: {"cache-control", "no-cache"},
+ 40: {"cache-control", "no-store"},
+ 41: {"cache-control", "public, max-age=31536000"},
+ 42: {"content-encoding", "br"},
+ 43: {"content-encoding", "gzip"},
+ 44: {"content-type", "application/dns-message"},
+ 45: {"content-type", "application/javascript"},
+ 46: {"content-type", "application/json"},
+ 47: {"content-type", "application/x-www-form-urlencoded"},
+ 48: {"content-type", "image/gif"},
+ 49: {"content-type", "image/jpeg"},
+ 50: {"content-type", "image/png"},
+ 51: {"content-type", "text/css"},
+ 52: {"content-type", "text/html; charset=utf-8"},
+ 53: {"content-type", "text/plain"},
+ 54: {"content-type", "text/plain;charset=utf-8"},
+ 55: {"range", "bytes=0-"},
+ 56: {"strict-transport-security", "max-age=31536000"},
+ 57: {"strict-transport-security", "max-age=31536000; includesubdomains"},
+ 58: {"strict-transport-security", "max-age=31536000; includesubdomains; preload"},
+ 59: {"vary", "accept-encoding"},
+ 60: {"vary", "origin"},
+ 61: {"x-content-type-options", "nosniff"},
+ 62: {"x-xss-protection", "1; mode=block"},
+ 63: {":status", "100"},
+ 64: {":status", "204"},
+ 65: {":status", "206"},
+ 66: {":status", "302"},
+ 67: {":status", "400"},
+ 68: {":status", "403"},
+ 69: {":status", "421"},
+ 70: {":status", "425"},
+ 71: {":status", "500"},
+ 72: {"accept-language", ""},
+ 73: {"access-control-allow-credentials", "FALSE"},
+ 74: {"access-control-allow-credentials", "TRUE"},
+ 75: {"access-control-allow-headers", "*"},
+ 76: {"access-control-allow-methods", "get"},
+ 77: {"access-control-allow-methods", "get, post, options"},
+ 78: {"access-control-allow-methods", "options"},
+ 79: {"access-control-expose-headers", "content-length"},
+ 80: {"access-control-request-headers", "content-type"},
+ 81: {"access-control-request-method", "get"},
+ 82: {"access-control-request-method", "post"},
+ 83: {"alt-svc", "clear"},
+ 84: {"authorization", ""},
+ 85: {"content-security-policy", "script-src 'none'; object-src 'none'; base-uri 'none'"},
+ 86: {"early-data", "1"},
+ 87: {"expect-ct", ""},
+ 88: {"forwarded", ""},
+ 89: {"if-range", ""},
+ 90: {"origin", ""},
+ 91: {"purpose", "prefetch"},
+ 92: {"server", ""},
+ 93: {"timing-allow-origin", "*"},
+ 94: {"upgrade-insecure-requests", "1"},
+ 95: {"user-agent", ""},
+ 96: {"x-forwarded-for", ""},
+ 97: {"x-frame-options", "deny"},
+ 98: {"x-frame-options", "sameorigin"},
+}
diff --git a/internal/http3/qpack_test.go b/internal/http3/qpack_test.go
new file mode 100644
index 0000000000..6e16511fc6
--- /dev/null
+++ b/internal/http3/qpack_test.go
@@ -0,0 +1,173 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestPrefixedInt(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, test := range []struct {
+ value int64
+ prefixLen uint8
+ encoded []byte
+ }{
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.1.1
+ {
+ value: 10,
+ prefixLen: 5,
+ encoded: []byte{
+ 0b_0000_1010,
+ },
+ },
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.1.2
+ {
+ value: 1337,
+ prefixLen: 5,
+ encoded: []byte{
+ 0b0001_1111,
+ 0b1001_1010,
+ 0b0000_1010,
+ },
+ },
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.1.3
+ {
+ value: 42,
+ prefixLen: 8,
+ encoded: []byte{
+ 0b0010_1010,
+ },
+ },
+ } {
+ highBitMask := ^((byte(1) << test.prefixLen) - 1)
+ for _, highBits := range []byte{
+ 0, highBitMask, 0b1010_1010 & highBitMask,
+ } {
+ gotEnc := appendPrefixedInt(nil, highBits, test.prefixLen, test.value)
+ wantEnc := append([]byte{}, test.encoded...)
+ wantEnc[0] |= highBits
+ if !bytes.Equal(gotEnc, wantEnc) {
+ t.Errorf("appendPrefixedInt(nil, 0b%08b, %v, %v) = {%x}, want {%x}",
+ highBits, test.prefixLen, test.value, gotEnc, wantEnc)
+ }
+
+ st1.Write(gotEnc)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ gotFirstByte, v, err := st2.readPrefixedInt(test.prefixLen)
+ if err != nil || gotFirstByte&highBitMask != highBits || v != test.value {
+ t.Errorf("st.readPrefixedInt(%v) = 0b%08b, %v, %v; want 0b%08b, %v, nil", test.prefixLen, gotFirstByte, v, err, highBits, test.value)
+ }
+ }
+ }
+}
+
+func TestPrefixedString(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, test := range []struct {
+ value string
+ prefixLen uint8
+ encoded []byte
+ }{
+ // https://www.rfc-editor.org/rfc/rfc7541#appendix-C.6.1
+ {
+ value: "302",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x82, // H bit + length 2
+ 0x64, 0x02,
+ },
+ },
+ {
+ value: "private",
+ prefixLen: 5,
+ encoded: []byte{
+ 0x25, // H bit + length 5
+ 0xae, 0xc3, 0x77, 0x1a, 0x4b,
+ },
+ },
+ {
+ value: "Mon, 21 Oct 2013 20:13:21 GMT",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x96, // H bit + length 22
+ 0xd0, 0x7a, 0xbe, 0x94, 0x10, 0x54, 0xd4, 0x44,
+ 0xa8, 0x20, 0x05, 0x95, 0x04, 0x0b, 0x81, 0x66,
+ 0xe0, 0x82, 0xa6, 0x2d, 0x1b, 0xff,
+ },
+ },
+ {
+ value: "https://www.example.com",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x91, // H bit + length 17
+ 0x9d, 0x29, 0xad, 0x17, 0x18, 0x63, 0xc7, 0x8f,
+ 0x0b, 0x97, 0xc8, 0xe9, 0xae, 0x82, 0xae, 0x43,
+ 0xd3,
+ },
+ },
+ // Not Huffman encoded (encoded size == unencoded size).
+ {
+ value: "a",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x01, // length 1
+ 0x61,
+ },
+ },
+ // Empty string.
+ {
+ value: "",
+ prefixLen: 7,
+ encoded: []byte{
+ 0x00, // length 0
+ },
+ },
+ } {
+ highBitMask := ^((byte(1) << (test.prefixLen + 1)) - 1)
+ for _, highBits := range []byte{
+ 0, highBitMask, 0b1010_1010 & highBitMask,
+ } {
+ gotEnc := appendPrefixedString(nil, highBits, test.prefixLen, test.value)
+ wantEnc := append([]byte{}, test.encoded...)
+ wantEnc[0] |= highBits
+ if !bytes.Equal(gotEnc, wantEnc) {
+ t.Errorf("appendPrefixedString(nil, 0b%08b, %v, %v) = {%x}, want {%x}",
+ highBits, test.prefixLen, test.value, gotEnc, wantEnc)
+ }
+
+ st1.Write(gotEnc)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ gotFirstByte, v, err := st2.readPrefixedString(test.prefixLen)
+ if err != nil || gotFirstByte&highBitMask != highBits || v != test.value {
+ t.Errorf("st.readPrefixedInt(%v) = 0b%08b, %q, %v; want 0b%08b, %q, nil", test.prefixLen, gotFirstByte, v, err, highBits, test.value)
+ }
+ }
+ }
+}
+
+func TestHuffmanDecodingFailure(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ st1.Write([]byte{
+ 0x82, // H bit + length 2 (only the first two payload bytes are consumed)
+ 0b_1111_1111,
+ 0b_1111_1111,
+ 0b_1111_1111,
+ 0b_1111_1111,
+ })
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ if b, v, err := st2.readPrefixedString(7); err == nil {
+ t.Fatalf("readPrefixedString(7) = %x, %v, nil; want error", b, v)
+ }
+}
diff --git a/internal/http3/quic.go b/internal/http3/quic.go
new file mode 100644
index 0000000000..6d2b120094
--- /dev/null
+++ b/internal/http3/quic.go
@@ -0,0 +1,42 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "crypto/tls"
+
+ "golang.org/x/net/quic"
+)
+
+func initConfig(config *quic.Config) *quic.Config {
+ if config == nil {
+ config = &quic.Config{}
+ }
+
+ // maybeCloneTLSConfig clones the user-provided tls.Config (but only once)
+ // prior to us modifying it.
+ needCloneTLSConfig := true
+ maybeCloneTLSConfig := func() *tls.Config {
+ if needCloneTLSConfig {
+ config.TLSConfig = config.TLSConfig.Clone()
+ needCloneTLSConfig = false
+ }
+ return config.TLSConfig
+ }
+
+ if config.TLSConfig == nil {
+ config.TLSConfig = &tls.Config{}
+ needCloneTLSConfig = false
+ }
+ if config.TLSConfig.MinVersion == 0 {
+ maybeCloneTLSConfig().MinVersion = tls.VersionTLS13
+ }
+ if config.TLSConfig.NextProtos == nil {
+ maybeCloneTLSConfig().NextProtos = []string{"h3"}
+ }
+ return config
+}
diff --git a/internal/http3/quic_test.go b/internal/http3/quic_test.go
new file mode 100644
index 0000000000..bc3b110fe9
--- /dev/null
+++ b/internal/http3/quic_test.go
@@ -0,0 +1,234 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "net"
+ "net/netip"
+ "runtime"
+ "sync"
+ "testing"
+ "time"
+
+ "golang.org/x/net/internal/gate"
+ "golang.org/x/net/internal/testcert"
+ "golang.org/x/net/quic"
+)
+
+// newLocalQUICEndpoint returns a QUIC Endpoint listening on localhost.
+func newLocalQUICEndpoint(t *testing.T) *quic.Endpoint {
+ t.Helper()
+ switch runtime.GOOS {
+ case "plan9":
+ t.Skipf("ReadMsgUDP not supported on %s", runtime.GOOS)
+ }
+ conf := &quic.Config{
+ TLSConfig: testTLSConfig,
+ }
+ e, err := quic.Listen("udp", "127.0.0.1:0", conf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ e.Close(context.Background())
+ })
+ return e
+}
+
+// newQUICEndpointPair returns two QUIC endpoints on the same test network.
+func newQUICEndpointPair(t testing.TB) (e1, e2 *quic.Endpoint) {
+ config := &quic.Config{
+ TLSConfig: testTLSConfig,
+ }
+ tn := &testNet{}
+ e1 = tn.newQUICEndpoint(t, config)
+ e2 = tn.newQUICEndpoint(t, config)
+ return e1, e2
+}
+
+// newQUICStreamPair returns the two sides of a bidirectional QUIC stream.
+func newQUICStreamPair(t testing.TB) (s1, s2 *quic.Stream) {
+ t.Helper()
+ config := &quic.Config{
+ TLSConfig: testTLSConfig,
+ }
+ e1, e2 := newQUICEndpointPair(t)
+ c1, err := e1.Dial(context.Background(), "udp", e2.LocalAddr().String(), config)
+ if err != nil {
+ t.Fatal(err)
+ }
+ c2, err := e2.Accept(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ s1, err = c1.NewStream(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ s1.Flush()
+ s2, err = c2.AcceptStream(context.Background())
+ if err != nil {
+ t.Fatal(err)
+ }
+ return s1, s2
+}
+
+// A testNet is a fake network of net.PacketConns.
+type testNet struct {
+ mu sync.Mutex
+ conns map[netip.AddrPort]*testPacketConn
+}
+
+// newPacketConn returns a new PacketConn with a unique source address.
+func (tn *testNet) newPacketConn() *testPacketConn {
+ tn.mu.Lock()
+ defer tn.mu.Unlock()
+ if tn.conns == nil {
+ tn.conns = make(map[netip.AddrPort]*testPacketConn)
+ }
+ localAddr := netip.AddrPortFrom(
+ netip.AddrFrom4([4]byte{
+ 127, 0, 0, byte(len(tn.conns)),
+ }),
+ 443)
+ tc := &testPacketConn{
+ tn: tn,
+ localAddr: localAddr,
+ gate: gate.New(false),
+ }
+ tn.conns[localAddr] = tc
+ return tc
+}
+
+func (tn *testNet) newQUICEndpoint(t testing.TB, config *quic.Config) *quic.Endpoint {
+ t.Helper()
+ pc := tn.newPacketConn()
+ e, err := quic.NewEndpoint(pc, config)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ e.Close(t.Context())
+ })
+ return e
+}
+
+// connForAddr returns the conn with the given source address.
+func (tn *testNet) connForAddr(srcAddr netip.AddrPort) *testPacketConn {
+ tn.mu.Lock()
+ defer tn.mu.Unlock()
+ return tn.conns[srcAddr]
+}
+
+// A testPacketConn is a net.PacketConn on a testNet fake network.
+type testPacketConn struct {
+ tn *testNet
+ localAddr netip.AddrPort
+
+ gate gate.Gate
+ queue []testPacket
+ closed bool
+}
+
+type testPacket struct {
+ b []byte
+ src netip.AddrPort
+}
+
+func (tc *testPacketConn) unlock() {
+ tc.gate.Unlock(tc.closed || len(tc.queue) > 0)
+}
+
+func (tc *testPacketConn) ReadFrom(p []byte) (n int, srcAddr net.Addr, err error) {
+ if err := tc.gate.WaitAndLock(context.Background()); err != nil {
+ return 0, nil, err
+ }
+ defer tc.unlock()
+ if tc.closed {
+ return 0, nil, net.ErrClosed
+ }
+ n = copy(p, tc.queue[0].b)
+ srcAddr = net.UDPAddrFromAddrPort(tc.queue[0].src)
+ tc.queue = tc.queue[1:]
+ return n, srcAddr, nil
+}
+
+func (tc *testPacketConn) WriteTo(p []byte, dstAddr net.Addr) (n int, err error) {
+ tc.gate.Lock()
+ closed := tc.closed
+ tc.unlock()
+ if closed {
+ return 0, net.ErrClosed
+ }
+
+ ap, err := addrPortFromAddr(dstAddr)
+ if err != nil {
+ return 0, err
+ }
+ dst := tc.tn.connForAddr(ap)
+ if dst == nil {
+ return len(p), nil // sent into the void
+ }
+ dst.gate.Lock()
+ defer dst.unlock()
+ dst.queue = append(dst.queue, testPacket{
+ b: bytes.Clone(p),
+ src: tc.localAddr,
+ })
+ return len(p), nil
+}
+
+func (tc *testPacketConn) Close() error {
+ tc.tn.mu.Lock()
+ tc.tn.conns[tc.localAddr] = nil
+ tc.tn.mu.Unlock()
+
+ tc.gate.Lock()
+ defer tc.unlock()
+ tc.closed = true
+ tc.queue = nil
+ return nil
+}
+
+func (tc *testPacketConn) LocalAddr() net.Addr {
+ return net.UDPAddrFromAddrPort(tc.localAddr)
+}
+
+func (tc *testPacketConn) SetDeadline(time.Time) error { panic("unimplemented") }
+func (tc *testPacketConn) SetReadDeadline(time.Time) error { panic("unimplemented") }
+func (tc *testPacketConn) SetWriteDeadline(time.Time) error { panic("unimplemented") }
+
+func addrPortFromAddr(addr net.Addr) (netip.AddrPort, error) {
+ switch a := addr.(type) {
+ case *net.UDPAddr:
+ return a.AddrPort(), nil
+ }
+ return netip.ParseAddrPort(addr.String())
+}
+
+var testTLSConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ CipherSuites: []uint16{
+ tls.TLS_AES_128_GCM_SHA256,
+ tls.TLS_AES_256_GCM_SHA384,
+ tls.TLS_CHACHA20_POLY1305_SHA256,
+ },
+ MinVersion: tls.VersionTLS13,
+ Certificates: []tls.Certificate{testCert},
+ NextProtos: []string{"h3"},
+}
+
+var testCert = func() tls.Certificate {
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+ if err != nil {
+ panic(err)
+ }
+ return cert
+}()
diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go
new file mode 100644
index 0000000000..bf55a13159
--- /dev/null
+++ b/internal/http3/roundtrip.go
@@ -0,0 +1,347 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "errors"
+ "io"
+ "net/http"
+ "strconv"
+ "sync"
+
+ "golang.org/x/net/internal/httpcommon"
+)
+
+type roundTripState struct {
+ cc *ClientConn
+ st *stream
+
+ // Request body, provided by the caller.
+ onceCloseReqBody sync.Once
+ reqBody io.ReadCloser
+
+ reqBodyWriter bodyWriter
+
+ // Response.Body, provided to the caller.
+ respBody bodyReader
+
+ errOnce sync.Once
+ err error
+}
+
+// abort terminates the RoundTrip.
+// It returns the first fatal error encountered by the RoundTrip call.
+func (rt *roundTripState) abort(err error) error {
+ rt.errOnce.Do(func() {
+ rt.err = err
+ switch e := err.(type) {
+ case *connectionError:
+ rt.cc.abort(e)
+ case *streamError:
+ rt.st.stream.CloseRead()
+ rt.st.stream.Reset(uint64(e.code))
+ default:
+ rt.st.stream.CloseRead()
+ rt.st.stream.Reset(uint64(errH3NoError))
+ }
+ })
+ return rt.err
+}
+
+// closeReqBody closes the Request.Body, at most once.
+func (rt *roundTripState) closeReqBody() {
+ if rt.reqBody != nil {
+ rt.onceCloseReqBody.Do(func() {
+ rt.reqBody.Close()
+ })
+ }
+}
+
+// RoundTrip sends a request on the connection.
+func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) {
+ // Each request gets its own QUIC stream.
+ st, err := newConnStream(req.Context(), cc.qconn, streamTypeRequest)
+ if err != nil {
+ return nil, err
+ }
+ rt := &roundTripState{
+ cc: cc,
+ st: st,
+ }
+ defer func() {
+ if err != nil {
+ err = rt.abort(err)
+ }
+ }()
+
+ // Cancel reads/writes on the stream when the request expires.
+ st.stream.SetReadContext(req.Context())
+ st.stream.SetWriteContext(req.Context())
+
+ contentLength := actualContentLength(req)
+
+ var encr httpcommon.EncodeHeadersResult
+ headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) {
+ encr, err = httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{
+ Request: httpcommon.Request{
+ URL: req.URL,
+ Method: req.Method,
+ Host: req.Host,
+ Header: req.Header,
+ Trailer: req.Trailer,
+ ActualContentLength: contentLength,
+ },
+ AddGzipHeader: false, // TODO: add when appropriate
+ PeerMaxHeaderListSize: 0,
+ DefaultUserAgent: "Go-http-client/3",
+ }, func(name, value string) {
+ // Issue #71374: Consider supporting never-indexed fields.
+ yield(mayIndex, name, value)
+ })
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Write the HEADERS frame.
+ st.writeVarint(int64(frameTypeHeaders))
+ st.writeVarint(int64(len(headers)))
+ st.Write(headers)
+ if err := st.Flush(); err != nil {
+ return nil, err
+ }
+
+ if encr.HasBody {
+ // TODO: Defer sending the request body when "Expect: 100-continue" is set.
+ rt.reqBody = req.Body
+ rt.reqBodyWriter.st = st
+ rt.reqBodyWriter.remain = contentLength
+ rt.reqBodyWriter.flush = true
+ rt.reqBodyWriter.name = "request"
+ go copyRequestBody(rt)
+ }
+
+ // Read the response headers.
+ for {
+ ftype, err := st.readFrameHeader()
+ if err != nil {
+ return nil, err
+ }
+ switch ftype {
+ case frameTypeHeaders:
+ statusCode, h, err := cc.handleHeaders(st)
+ if err != nil {
+ return nil, err
+ }
+
+ if statusCode >= 100 && statusCode < 199 {
+ // TODO: Handle 1xx responses.
+ continue
+ }
+
+ // We have the response headers.
+ // Set up the response and return it to the caller.
+ contentLength, err := parseResponseContentLength(req.Method, statusCode, h)
+ if err != nil {
+ return nil, err
+ }
+ rt.respBody.st = st
+ rt.respBody.remain = contentLength
+ resp := &http.Response{
+ Proto: "HTTP/3.0",
+ ProtoMajor: 3,
+ Header: h,
+ StatusCode: statusCode,
+ Status: strconv.Itoa(statusCode) + " " + http.StatusText(statusCode),
+ ContentLength: contentLength,
+ Body: (*transportResponseBody)(rt),
+ }
+ // TODO: Automatic Content-Type: gzip decoding.
+ return resp, nil
+ case frameTypePushPromise:
+ if err := cc.handlePushPromise(st); err != nil {
+ return nil, err
+ }
+ default:
+ if err := st.discardUnknownFrame(ftype); err != nil {
+ return nil, err
+ }
+ }
+ }
+}
+
+// actualContentLength returns a sanitized version of req.ContentLength,
+// where 0 actually means zero (not unknown) and -1 means unknown.
+func actualContentLength(req *http.Request) int64 {
+ if req.Body == nil || req.Body == http.NoBody {
+ return 0
+ }
+ if req.ContentLength != 0 {
+ return req.ContentLength
+ }
+ return -1
+}
+
+func copyRequestBody(rt *roundTripState) {
+ defer rt.closeReqBody()
+ _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody)
+ if closeErr := rt.reqBodyWriter.Close(); err == nil {
+ err = closeErr
+ }
+ if err != nil {
+ // Something went wrong writing the body.
+ rt.abort(err)
+ } else {
+ // We wrote the whole body.
+ rt.st.stream.CloseWrite()
+ }
+}
+
+// transportResponseBody is the Response.Body returned by RoundTrip.
+type transportResponseBody roundTripState
+
+// Read is Response.Body.Read.
+func (b *transportResponseBody) Read(p []byte) (n int, err error) {
+ return b.respBody.Read(p)
+}
+
+var errRespBodyClosed = errors.New("response body closed")
+
+// Close is Response.Body.Close.
+// Closing the response body is how the caller signals that they're done with a request.
+func (b *transportResponseBody) Close() error {
+ rt := (*roundTripState)(b)
+ // Close the request body, which should wake up copyRequestBody if it's
+ // currently blocked reading the body.
+ rt.closeReqBody()
+ // Close the request stream, since we're done with the request.
+ // Reset closes the sending half of the stream.
+ rt.st.stream.Reset(uint64(errH3NoError))
+ // respBody.Close is responsible for closing the receiving half.
+ err := rt.respBody.Close()
+ if err == nil {
+ err = errRespBodyClosed
+ }
+ err = rt.abort(err)
+ if err == errRespBodyClosed {
+ // No other errors occurred before closing Response.Body,
+ // so consider this a successful request.
+ return nil
+ }
+ return err
+}
+
+func parseResponseContentLength(method string, statusCode int, h http.Header) (int64, error) {
+ clens := h["Content-Length"]
+ if len(clens) == 0 {
+ return -1, nil
+ }
+
+ // We allow duplicate Content-Length headers,
+ // but only if they all have the same value.
+ for _, v := range clens[1:] {
+ if clens[0] != v {
+ return -1, &streamError{errH3MessageError, "mismatching Content-Length headers"}
+ }
+ }
+
+ // "A server MUST NOT send a Content-Length header field in any response
+ // with a status code of 1xx (Informational) or 204 (No Content).
+ // A server MUST NOT send a Content-Length header field in any 2xx (Successful)
+ // response to a CONNECT request [...]"
+ // https://www.rfc-editor.org/rfc/rfc9110#section-8.6-8
+ if (statusCode >= 100 && statusCode < 200) ||
+ statusCode == 204 ||
+ (method == "CONNECT" && statusCode >= 200 && statusCode < 300) {
+ // This is a protocol violation, but a fairly harmless one.
+ // Just ignore the header.
+ return -1, nil
+ }
+
+ contentLen, err := strconv.ParseUint(clens[0], 10, 63)
+ if err != nil {
+ return -1, &streamError{errH3MessageError, "invalid Content-Length header"}
+ }
+ return int64(contentLen), nil
+}
+
+func (cc *ClientConn) handleHeaders(st *stream) (statusCode int, h http.Header, err error) {
+ haveStatus := false
+ cookie := ""
+ // Issue #71374: Consider tracking the never-indexed status of headers
+ // with the N bit set in their QPACK encoding.
+ err = cc.dec.decode(st, func(_ indexType, name, value string) error {
+ switch {
+ case name == ":status":
+ if haveStatus {
+ return &streamError{errH3MessageError, "duplicate :status"}
+ }
+ haveStatus = true
+ statusCode, err = strconv.Atoi(value)
+ if err != nil {
+ return &streamError{errH3MessageError, "invalid :status"}
+ }
+ case name[0] == ':':
+ // "Endpoints MUST treat a request or response
+ // that contains undefined or invalid
+ // pseudo-header fields as malformed."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3-3
+ return &streamError{errH3MessageError, "undefined pseudo-header"}
+ case name == "cookie":
+ // "If a decompressed field section contains multiple cookie field lines,
+ // these MUST be concatenated into a single byte string [...]"
+ // using the two-byte delimiter of "; "
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2
+ if cookie == "" {
+ cookie = value
+ } else {
+ cookie += "; " + value
+ }
+ default:
+ if h == nil {
+ h = make(http.Header)
+ }
+ // TODO: Use a per-connection canonicalization cache as we do in HTTP/2.
+ // Maybe we could put this in the QPACK decoder and have it deliver
+ // pre-canonicalized headers to us here?
+ cname := httpcommon.CanonicalHeader(name)
+ // TODO: Consider using a single []string slice for all headers,
+ // as we do in the HTTP/1 and HTTP/2 cases.
+ // This is a bit tricky, since we don't know the number of headers
+ // at the start of decoding. Perhaps it's worth doing a two-pass decode,
+ // or perhaps we should just allocate header value slices in
+ // reasonably-sized chunks.
+ h[cname] = append(h[cname], value)
+ }
+ return nil
+ })
+ if !haveStatus {
+ // "[The :status] pseudo-header field MUST be included in all responses [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.3.2-1
+ err = errH3MessageError
+ }
+ if cookie != "" {
+ if h == nil {
+ h = make(http.Header)
+ }
+ h["Cookie"] = []string{cookie}
+ }
+ if err := st.endFrame(); err != nil {
+ return 0, nil, err
+ }
+ return statusCode, h, err
+}
+
+func (cc *ClientConn) handlePushPromise(st *stream) error {
+ // "A client MUST treat receipt of a PUSH_PROMISE frame that contains a
+ // larger push ID than the client has advertised as a connection error of H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5
+ return &connectionError{
+ code: errH3IDError,
+ message: "PUSH_PROMISE received when no MAX_PUSH_ID has been sent",
+ }
+}
diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go
new file mode 100644
index 0000000000..acd8613d0e
--- /dev/null
+++ b/internal/http3/roundtrip_test.go
@@ -0,0 +1,354 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "net/http"
+ "testing"
+ "testing/synctest"
+
+ "golang.org/x/net/quic"
+)
+
+func TestRoundTripSimple(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ req.Header["User-Agent"] = nil
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(http.Header{
+ ":authority": []string{"example.tld"},
+ ":method": []string{"GET"},
+ ":path": []string{"/"},
+ ":scheme": []string{"https"},
+ })
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ "x-some-header": []string{"value"},
+ })
+ rt.wantStatus(200)
+ rt.wantHeaders(http.Header{
+ "X-Some-Header": []string{"value"},
+ })
+ })
+}
+
+func TestRoundTripWithBadHeaders(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ req.Header["Invalid\nHeader"] = []string{"x"}
+ rt := tc.roundTrip(req)
+ rt.wantError("RoundTrip fails when request contains invalid headers")
+ })
+}
+
+func TestRoundTripWithUnknownFrame(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ // Write an unknown frame type before the response HEADERS.
+ data := "frame content"
+ st.writeVarint(0x1f + 0x21) // reserved frame type
+ st.writeVarint(int64(len(data))) // size
+ st.Write([]byte(data))
+
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ })
+ rt.wantStatus(200)
+ })
+}
+
+func TestRoundTripWithInvalidPushPromise(t *testing.T) {
+ // "A client MUST treat receipt of a PUSH_PROMISE frame that contains
+ // a larger push ID than the client has advertised as a connection error of H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.5-5
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ // Write a PUSH_PROMISE frame.
+ // Since the client hasn't indicated willingness to accept pushes,
+ // this is a connection error.
+ st.writePushPromise(0, http.Header{
+ ":path": []string{"/foo"},
+ })
+ rt.wantError("RoundTrip fails after receiving invalid PUSH_PROMISE")
+ tc.wantClosed(
+ "push ID exceeds client's MAX_PUSH_ID",
+ errH3IDError,
+ )
+ })
+}
+
+func TestRoundTripResponseContentLength(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ respHeader http.Header
+ wantContentLength int64
+ wantError bool
+ }{{
+ name: "valid",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"100"},
+ },
+ wantContentLength: 100,
+ }, {
+ name: "absent",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ },
+ wantContentLength: -1,
+ }, {
+ name: "unparseable",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"1 1"},
+ },
+ wantError: true,
+ }, {
+ name: "duplicated",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"500", "500", "500"},
+ },
+ wantContentLength: 500,
+ }, {
+ name: "inconsistent",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ "content-length": []string{"1", "2"},
+ },
+ wantError: true,
+ }, {
+ // 204 responses aren't allowed to contain a Content-Length header.
+ // We just ignore it.
+ name: "204",
+ respHeader: http.Header{
+ ":status": []string{"204"},
+ "content-length": []string{"100"},
+ },
+ wantContentLength: -1,
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(test.respHeader)
+ if test.wantError {
+ rt.wantError("invalid content-length in response")
+ return
+ }
+ if got, want := rt.response().ContentLength, test.wantContentLength; got != want {
+ t.Errorf("Response.ContentLength = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+func TestRoundTripMalformedResponses(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ respHeader http.Header
+ }{{
+ name: "duplicate :status",
+ respHeader: http.Header{
+ ":status": []string{"200", "204"},
+ },
+ }, {
+ name: "unparseable :status",
+ respHeader: http.Header{
+ ":status": []string{"frogpants"},
+ },
+ }, {
+ name: "undefined pseudo-header",
+ respHeader: http.Header{
+ ":status": []string{"200"},
+ ":unknown": []string{"x"},
+ },
+ }, {
+ name: "no :status",
+ respHeader: http.Header{},
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(test.respHeader)
+ rt.wantError("malformed response")
+ })
+ }
+}
+
+func TestRoundTripCrumbledCookiesInResponse(t *testing.T) {
+ // "If a decompressed field section contains multiple cookie field lines,
+ // these MUST be concatenated into a single byte string [...]"
+ // using the two-byte delimiter of "; "
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2.1-2
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", nil)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ "cookie": []string{"a=1", "b=2; c=3", "d=4"},
+ })
+ rt.wantStatus(200)
+ rt.wantHeaders(http.Header{
+ "Cookie": []string{"a=1; b=2; c=3; d=4"},
+ })
+ })
+}
+
+func TestRoundTripRequestBodySent(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ bodyr, bodyw := io.Pipe()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", bodyr)
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+ st.wantHeaders(nil)
+
+ bodyw.Write([]byte{0, 1, 2, 3, 4})
+ st.wantData([]byte{0, 1, 2, 3, 4})
+
+ bodyw.Write([]byte{5, 6, 7})
+ st.wantData([]byte{5, 6, 7})
+
+ bodyw.Close()
+ st.wantClosed("request body sent")
+
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ })
+ rt.wantStatus(200)
+ rt.response().Body.Close()
+ })
+}
+
+func TestRoundTripRequestBodyErrors(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ body io.Reader
+ contentLength int64
+ }{{
+ name: "too short",
+ contentLength: 10,
+ body: bytes.NewReader([]byte{0, 1, 2, 3, 4}),
+ }, {
+ name: "too long",
+ contentLength: 5,
+ body: bytes.NewReader([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}),
+ }, {
+ name: "read error",
+ body: io.MultiReader(
+ bytes.NewReader([]byte{0, 1, 2, 3, 4}),
+ &testReader{
+ readFunc: func([]byte) (int, error) {
+ return 0, errors.New("read error")
+ },
+ },
+ ),
+ }} {
+ runSynctestSubtest(t, test.name, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ req, _ := http.NewRequest("GET", "https://example.tld/", test.body)
+ req.ContentLength = test.contentLength
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+
+ // The Transport should send some number of frames before detecting an
+ // error in the request body and aborting the request.
+ synctest.Wait()
+ for {
+ _, err := st.readFrameHeader()
+ if err != nil {
+ var code quic.StreamErrorCode
+ if !errors.As(err, &code) {
+ t.Fatalf("request stream closed with error %v: want QUIC stream error", err)
+ }
+ break
+ }
+ if err := st.discardFrame(); err != nil {
+ t.Fatalf("discardFrame: %v", err)
+ }
+ }
+
+ // RoundTrip returns with an error.
+ rt.wantError("request fails due to body error")
+ })
+ }
+}
+
+func TestRoundTripRequestBodyErrorAfterHeaders(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+
+ bodyr, bodyw := io.Pipe()
+ req, _ := http.NewRequest("GET", "https://example.tld/", bodyr)
+ req.ContentLength = 10
+ rt := tc.roundTrip(req)
+ st := tc.wantStream(streamTypeRequest)
+
+ // Server sends response headers, and RoundTrip returns.
+ // The request body hasn't been sent yet.
+ st.wantHeaders(nil)
+ st.writeHeaders(http.Header{
+ ":status": []string{"200"},
+ })
+ rt.wantStatus(200)
+
+ // Write too many bytes to the request body, triggering a request error.
+ bodyw.Write(make([]byte, req.ContentLength+1))
+
+ //io.Copy(io.Discard, st)
+ st.wantError(quic.StreamErrorCode(errH3InternalError))
+
+ if err := rt.response().Body.Close(); err == nil {
+ t.Fatalf("Response.Body.Close() = %v, want error", err)
+ }
+ })
+}
diff --git a/internal/http3/server.go b/internal/http3/server.go
new file mode 100644
index 0000000000..ca93c5298a
--- /dev/null
+++ b/internal/http3/server.go
@@ -0,0 +1,172 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "net/http"
+ "sync"
+
+ "golang.org/x/net/quic"
+)
+
+// A Server is an HTTP/3 server.
+// The zero value for Server is a valid server.
+type Server struct {
+ // Handler to invoke for requests, http.DefaultServeMux if nil.
+ Handler http.Handler
+
+ // Config is the QUIC configuration used by the server.
+ // The Config may be nil.
+ //
+ // ListenAndServe may clone and modify the Config.
+ // The Config must not be modified after calling ListenAndServe.
+ Config *quic.Config
+
+ initOnce sync.Once
+}
+
+func (s *Server) init() {
+ s.initOnce.Do(func() {
+ s.Config = initConfig(s.Config)
+ if s.Handler == nil {
+ s.Handler = http.DefaultServeMux
+ }
+ })
+}
+
+// ListenAndServe listens on the UDP network address addr
+// and then calls Serve to handle requests on incoming connections.
+func (s *Server) ListenAndServe(addr string) error {
+ s.init()
+ e, err := quic.Listen("udp", addr, s.Config)
+ if err != nil {
+ return err
+ }
+ return s.Serve(e)
+}
+
+// Serve accepts incoming connections on the QUIC endpoint e,
+// and handles requests from those connections.
+func (s *Server) Serve(e *quic.Endpoint) error {
+ s.init()
+ for {
+ qconn, err := e.Accept(context.Background())
+ if err != nil {
+ return err
+ }
+ go newServerConn(qconn)
+ }
+}
+
+type serverConn struct {
+ qconn *quic.Conn
+
+ genericConn // for handleUnidirectionalStream
+ enc qpackEncoder
+ dec qpackDecoder
+}
+
+func newServerConn(qconn *quic.Conn) {
+ sc := &serverConn{
+ qconn: qconn,
+ }
+ sc.enc.init()
+
+ // Create control stream and send SETTINGS frame.
+ // TODO: Time out on creating stream.
+ controlStream, err := newConnStream(context.Background(), sc.qconn, streamTypeControl)
+ if err != nil {
+ return
+ }
+ controlStream.writeSettings()
+ controlStream.Flush()
+
+ sc.acceptStreams(sc.qconn, sc)
+}
+
+func (sc *serverConn) handleControlStream(st *stream) error {
+ // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2
+ if err := st.readSettings(func(settingsType, settingsValue int64) error {
+ switch settingsType {
+ case settingsMaxFieldSectionSize:
+ _ = settingsValue // TODO
+ case settingsQPACKMaxTableCapacity:
+ _ = settingsValue // TODO
+ case settingsQPACKBlockedStreams:
+ _ = settingsValue // TODO
+ default:
+ // Unknown settings types are ignored.
+ }
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ for {
+ ftype, err := st.readFrameHeader()
+ if err != nil {
+ return err
+ }
+ switch ftype {
+ case frameTypeCancelPush:
+ // "If a server receives a CANCEL_PUSH frame for a push ID
+ // that has not yet been mentioned by a PUSH_PROMISE frame,
+ // this MUST be treated as a connection error of type H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-8
+ return &connectionError{
+ code: errH3IDError,
+ message: "CANCEL_PUSH for unsent push ID",
+ }
+ case frameTypeGoaway:
+ return errH3NoError
+ default:
+ // Unknown frames are ignored.
+ if err := st.discardUnknownFrame(ftype); err != nil {
+ return err
+ }
+ }
+ }
+}
+
+func (sc *serverConn) handleEncoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (sc *serverConn) handleDecoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (sc *serverConn) handlePushStream(*stream) error {
+ // "[...] if a server receives a client-initiated push stream,
+ // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3
+ return &connectionError{
+ code: errH3StreamCreationError,
+ message: "client created push stream",
+ }
+}
+
+func (sc *serverConn) handleRequestStream(st *stream) error {
+ // TODO
+ return nil
+}
+
+// abort closes the connection with an error.
+func (sc *serverConn) abort(err error) {
+ if e, ok := err.(*connectionError); ok {
+ sc.qconn.Abort(&quic.ApplicationError{
+ Code: uint64(e.code),
+ Reason: e.message,
+ })
+ } else {
+ sc.qconn.Abort(err)
+ }
+}
diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go
new file mode 100644
index 0000000000..8e727d2512
--- /dev/null
+++ b/internal/http3/server_test.go
@@ -0,0 +1,110 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "net/netip"
+ "testing"
+ "testing/synctest"
+
+ "golang.org/x/net/internal/quic/quicwire"
+ "golang.org/x/net/quic"
+)
+
+func TestServerReceivePushStream(t *testing.T) {
+ // "[...] if a server receives a client-initiated push stream,
+ // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3
+ runSynctest(t, func(t testing.TB) {
+ ts := newTestServer(t)
+ tc := ts.connect()
+ tc.newStream(streamTypePush)
+ tc.wantClosed("invalid client-created push stream", errH3StreamCreationError)
+ })
+}
+
+func TestServerCancelPushForUnsentPromise(t *testing.T) {
+ runSynctest(t, func(t testing.TB) {
+ ts := newTestServer(t)
+ tc := ts.connect()
+ tc.greet()
+
+ const pushID = 100
+ tc.control.writeVarint(int64(frameTypeCancelPush))
+ tc.control.writeVarint(int64(quicwire.SizeVarint(pushID)))
+ tc.control.writeVarint(pushID)
+ tc.control.Flush()
+
+ tc.wantClosed("client canceled never-sent push ID", errH3IDError)
+ })
+}
+
+type testServer struct {
+ t testing.TB
+ s *Server
+ tn testNet
+ *testQUICEndpoint
+
+ addr netip.AddrPort
+}
+
+type testQUICEndpoint struct {
+ t testing.TB
+ e *quic.Endpoint
+}
+
+func (te *testQUICEndpoint) dial() {
+}
+
+type testServerConn struct {
+ ts *testServer
+
+ *testQUICConn
+ control *testQUICStream
+}
+
+func newTestServer(t testing.TB) *testServer {
+ t.Helper()
+ ts := &testServer{
+ t: t,
+ s: &Server{
+ Config: &quic.Config{
+ TLSConfig: testTLSConfig,
+ },
+ },
+ }
+ e := ts.tn.newQUICEndpoint(t, ts.s.Config)
+ ts.addr = e.LocalAddr()
+ go ts.s.Serve(e)
+ return ts
+}
+
+func (ts *testServer) connect() *testServerConn {
+ ts.t.Helper()
+ config := &quic.Config{TLSConfig: testTLSConfig}
+ e := ts.tn.newQUICEndpoint(ts.t, nil)
+ qconn, err := e.Dial(ts.t.Context(), "udp", ts.addr.String(), config)
+ if err != nil {
+ ts.t.Fatal(err)
+ }
+ tc := &testServerConn{
+ ts: ts,
+ testQUICConn: newTestQUICConn(ts.t, qconn),
+ }
+ synctest.Wait()
+ return tc
+}
+
+// greet performs initial connection handshaking with the server.
+func (tc *testServerConn) greet() {
+ // Client creates a control stream.
+ tc.control = tc.newStream(streamTypeControl)
+ tc.control.writeVarint(int64(frameTypeSettings))
+ tc.control.writeVarint(0) // size
+ tc.control.Flush()
+ synctest.Wait()
+}
diff --git a/internal/http3/settings.go b/internal/http3/settings.go
new file mode 100644
index 0000000000..b5e562ecad
--- /dev/null
+++ b/internal/http3/settings.go
@@ -0,0 +1,72 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "golang.org/x/net/internal/quic/quicwire"
+)
+
+const (
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1
+ settingsMaxFieldSectionSize = 0x06
+
+ // https://www.rfc-editor.org/rfc/rfc9204.html#section-5
+ settingsQPACKMaxTableCapacity = 0x01
+ settingsQPACKBlockedStreams = 0x07
+)
+
+// writeSettings writes a complete SETTINGS frame.
+// Its parameter is a list of alternating setting types and values.
+func (st *stream) writeSettings(settings ...int64) {
+ var size int64
+ for _, s := range settings {
+ // Settings values that don't fit in a QUIC varint ([0,2^62)) will panic here.
+ size += int64(quicwire.SizeVarint(uint64(s)))
+ }
+ st.writeVarint(int64(frameTypeSettings))
+ st.writeVarint(size)
+ for _, s := range settings {
+ st.writeVarint(s)
+ }
+}
+
+// readSettings reads a complete SETTINGS frame, including the frame header.
+func (st *stream) readSettings(f func(settingType, value int64) error) error {
+ frameType, err := st.readFrameHeader()
+ if err != nil || frameType != frameTypeSettings {
+ return &connectionError{
+ code: errH3MissingSettings,
+ message: "settings not sent on control stream",
+ }
+ }
+ for st.lim > 0 {
+ settingsType, err := st.readVarint()
+ if err != nil {
+ return err
+ }
+ settingsValue, err := st.readVarint()
+ if err != nil {
+ return err
+ }
+
+ // Use of HTTP/2 settings where there is no corresponding HTTP/3 setting
+ // is an error.
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-5
+ switch settingsType {
+ case 0x02, 0x03, 0x04, 0x05:
+ return &connectionError{
+ code: errH3SettingsError,
+ message: "use of reserved setting",
+ }
+ }
+
+ if err := f(settingsType, settingsValue); err != nil {
+ return err
+ }
+ }
+ return st.endFrame()
+}
diff --git a/internal/http3/stream.go b/internal/http3/stream.go
new file mode 100644
index 0000000000..0f975407be
--- /dev/null
+++ b/internal/http3/stream.go
@@ -0,0 +1,262 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "io"
+
+ "golang.org/x/net/quic"
+)
+
+// A stream wraps a QUIC stream, providing methods to read/write various values.
+type stream struct {
+ stream *quic.Stream
+
+ // lim is the current read limit.
+ // Reading a frame header sets the limit to the end of the frame.
+ // Reading past the limit or reading less than the limit and ending the frame
+ // results in an error.
+ // -1 indicates no limit.
+ lim int64
+}
+
+// newConnStream creates a new stream on a connection.
+// It writes the stream header for unidirectional streams.
+//
+// The stream returned by newConnStream is not flushed,
+// and will not be sent to the peer until the caller calls
+// Flush or writes enough data to the stream.
+func newConnStream(ctx context.Context, qconn *quic.Conn, stype streamType) (*stream, error) {
+ var qs *quic.Stream
+ var err error
+ if stype == streamTypeRequest {
+ // Request streams are bidirectional.
+ qs, err = qconn.NewStream(ctx)
+ } else {
+ // All other streams are unidirectional.
+ qs, err = qconn.NewSendOnlyStream(ctx)
+ }
+ if err != nil {
+ return nil, err
+ }
+ st := &stream{
+ stream: qs,
+ lim: -1, // no limit
+ }
+ if stype != streamTypeRequest {
+ // Unidirectional stream header.
+ st.writeVarint(int64(stype))
+ }
+ return st, err
+}
+
+func newStream(qs *quic.Stream) *stream {
+ return &stream{
+ stream: qs,
+ lim: -1, // no limit
+ }
+}
+
+// readFrameHeader reads the type and length fields of an HTTP/3 frame.
+// It sets the read limit to the end of the frame.
+//
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-7.1
+func (st *stream) readFrameHeader() (ftype frameType, err error) {
+ if st.lim >= 0 {
+ // We shouldn't call readFrameHeader before ending the previous frame.
+ return 0, errH3FrameError
+ }
+ ftype, err = readVarint[frameType](st)
+ if err != nil {
+ return 0, err
+ }
+ size, err := st.readVarint()
+ if err != nil {
+ return 0, err
+ }
+ st.lim = size
+ return ftype, nil
+}
+
+// endFrame is called after reading a frame to reset the read limit.
+// It returns an error if the entire contents of a frame have not been read.
+func (st *stream) endFrame() error {
+ if st.lim != 0 {
+ return &connectionError{
+ code: errH3FrameError,
+ message: "invalid HTTP/3 frame",
+ }
+ }
+ st.lim = -1
+ return nil
+}
+
+// readFrameData returns the remaining data in the current frame.
+func (st *stream) readFrameData() ([]byte, error) {
+ if st.lim < 0 {
+ return nil, errH3FrameError
+ }
+ // TODO: Pool buffers to avoid allocation here.
+ b := make([]byte, st.lim)
+ _, err := io.ReadFull(st, b)
+ if err != nil {
+ return nil, err
+ }
+ return b, nil
+}
+
+// ReadByte reads one byte from the stream.
+func (st *stream) ReadByte() (b byte, err error) {
+ if err := st.recordBytesRead(1); err != nil {
+ return 0, err
+ }
+ b, err = st.stream.ReadByte()
+ if err != nil {
+ if err == io.EOF && st.lim < 0 {
+ return 0, io.EOF
+ }
+ return 0, errH3FrameError
+ }
+ return b, nil
+}
+
+// Read reads from the stream.
+func (st *stream) Read(b []byte) (int, error) {
+ n, err := st.stream.Read(b)
+ if e2 := st.recordBytesRead(n); e2 != nil {
+ return 0, e2
+ }
+ if err == io.EOF {
+ if st.lim == 0 {
+ // EOF at end of frame, ignore.
+ return n, nil
+ } else if st.lim > 0 {
+ // EOF inside frame, error.
+ return 0, errH3FrameError
+ } else {
+ // EOF outside of frame, surface to caller.
+ return n, io.EOF
+ }
+ }
+ if err != nil {
+ return 0, errH3FrameError
+ }
+ return n, nil
+}
+
+// discardUnknownFrame discards an unknown frame.
+//
+// HTTP/3 requires that unknown frames be ignored on all streams.
+// However, a known frame appearing in an unexpected place is a fatal error,
+// so this returns an error if the frame is one we know.
+func (st *stream) discardUnknownFrame(ftype frameType) error {
+ switch ftype {
+ case frameTypeData,
+ frameTypeHeaders,
+ frameTypeCancelPush,
+ frameTypeSettings,
+ frameTypePushPromise,
+ frameTypeGoaway,
+ frameTypeMaxPushID:
+ return &connectionError{
+ code: errH3FrameUnexpected,
+ message: "unexpected " + ftype.String() + " frame",
+ }
+ }
+ return st.discardFrame()
+}
+
+// discardFrame discards any remaining data in the current frame and resets the read limit.
+func (st *stream) discardFrame() error {
+ // TODO: Consider adding a *quic.Stream method to discard some amount of data.
+ for range st.lim {
+ _, err := st.stream.ReadByte()
+ if err != nil {
+ return &streamError{errH3FrameError, err.Error()}
+ }
+ }
+ st.lim = -1
+ return nil
+}
+
+// Write writes to the stream.
+func (st *stream) Write(b []byte) (int, error) { return st.stream.Write(b) }
+
+// Flush commits data written to the stream.
+func (st *stream) Flush() error { return st.stream.Flush() }
+
+// readVarint reads a QUIC variable-length integer from the stream.
+func (st *stream) readVarint() (v int64, err error) {
+ b, err := st.stream.ReadByte()
+ if err != nil {
+ return 0, err
+ }
+ v = int64(b & 0x3f)
+ n := 1 << (b >> 6)
+ for i := 1; i < n; i++ {
+ b, err := st.stream.ReadByte()
+ if err != nil {
+ return 0, errH3FrameError
+ }
+ v = (v << 8) | int64(b)
+ }
+ if err := st.recordBytesRead(n); err != nil {
+ return 0, err
+ }
+ return v, nil
+}
+
+// readVarint reads a varint of a particular type.
+func readVarint[T ~int64 | ~uint64](st *stream) (T, error) {
+ v, err := st.readVarint()
+ return T(v), err
+}
+
+// writeVarint writes a QUIC variable-length integer to the stream.
+func (st *stream) writeVarint(v int64) {
+ switch {
+ case v <= (1<<6)-1:
+ st.stream.WriteByte(byte(v))
+ case v <= (1<<14)-1:
+ st.stream.WriteByte((1 << 6) | byte(v>>8))
+ st.stream.WriteByte(byte(v))
+ case v <= (1<<30)-1:
+ st.stream.WriteByte((2 << 6) | byte(v>>24))
+ st.stream.WriteByte(byte(v >> 16))
+ st.stream.WriteByte(byte(v >> 8))
+ st.stream.WriteByte(byte(v))
+ case v <= (1<<62)-1:
+ st.stream.WriteByte((3 << 6) | byte(v>>56))
+ st.stream.WriteByte(byte(v >> 48))
+ st.stream.WriteByte(byte(v >> 40))
+ st.stream.WriteByte(byte(v >> 32))
+ st.stream.WriteByte(byte(v >> 24))
+ st.stream.WriteByte(byte(v >> 16))
+ st.stream.WriteByte(byte(v >> 8))
+ st.stream.WriteByte(byte(v))
+ default:
+ panic("varint too large")
+ }
+}
+
+// recordBytesRead records that n bytes have been read.
+// It returns an error if the read passes the current limit.
+func (st *stream) recordBytesRead(n int) error {
+ if st.lim < 0 {
+ return nil
+ }
+ st.lim -= int64(n)
+ if st.lim < 0 {
+ st.stream = nil // panic if we try to read again
+ return &connectionError{
+ code: errH3FrameError,
+ message: "invalid HTTP/3 frame",
+ }
+ }
+ return nil
+}
diff --git a/internal/http3/stream_test.go b/internal/http3/stream_test.go
new file mode 100644
index 0000000000..12b281c558
--- /dev/null
+++ b/internal/http3/stream_test.go
@@ -0,0 +1,319 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "testing"
+
+ "golang.org/x/net/internal/quic/quicwire"
+)
+
+func TestStreamReadVarint(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, b := range [][]byte{
+ {0x00},
+ {0x3f},
+ {0x40, 0x00},
+ {0x7f, 0xff},
+ {0x80, 0x00, 0x00, 0x00},
+ {0xbf, 0xff, 0xff, 0xff},
+ {0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
+ // Example cases from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.1
+ {0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c},
+ {0x9d, 0x7f, 0x3e, 0x7d},
+ {0x7b, 0xbd},
+ {0x25},
+ {0x40, 0x25},
+ } {
+ trailer := []byte{0xde, 0xad, 0xbe, 0xef}
+ st1.Write(b)
+ st1.Write(trailer)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ got, err := st2.readVarint()
+ if err != nil {
+ t.Fatalf("st.readVarint() = %v", err)
+ }
+ want, _ := quicwire.ConsumeVarintInt64(b)
+ if got != want {
+ t.Fatalf("st.readVarint() = %v, want %v", got, want)
+ }
+ gotTrailer := make([]byte, len(trailer))
+ if _, err := io.ReadFull(st2, gotTrailer); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(gotTrailer, trailer) {
+ t.Fatalf("after st.readVarint, read %x, want %x", gotTrailer, trailer)
+ }
+ }
+}
+
+func TestStreamWriteVarint(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, v := range []int64{
+ 0,
+ 63,
+ 16383,
+ 1073741823,
+ 4611686018427387903,
+ // Example cases from https://www.rfc-editor.org/rfc/rfc9000.html#section-a.1
+ 151288809941952652,
+ 494878333,
+ 15293,
+ 37,
+ } {
+ trailer := []byte{0xde, 0xad, 0xbe, 0xef}
+ st1.writeVarint(v)
+ st1.Write(trailer)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ want := quicwire.AppendVarint(nil, uint64(v))
+ want = append(want, trailer...)
+
+ got := make([]byte, len(want))
+ if _, err := io.ReadFull(st2, got); err != nil {
+ t.Fatal(err)
+ }
+
+ if !bytes.Equal(got, want) {
+		t.Errorf("st.writeVarint(%v) wrote %x, want %x", v, got, want)
+ }
+ }
+}
+
+func TestStreamReadFrames(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ for _, frame := range []struct {
+ ftype frameType
+ data []byte
+ }{{
+ ftype: 1,
+ data: []byte("hello"),
+ }, {
+ ftype: 2,
+ data: []byte{},
+ }, {
+ ftype: 3,
+ data: []byte("goodbye"),
+ }} {
+ st1.writeVarint(int64(frame.ftype))
+ st1.writeVarint(int64(len(frame.data)))
+ st1.Write(frame.data)
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if gotFrameType, err := st2.readFrameHeader(); err != nil || gotFrameType != frame.ftype {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", gotFrameType, err, frame.ftype)
+ }
+ if gotData, err := st2.readFrameData(); err != nil || !bytes.Equal(gotData, frame.data) {
+ t.Fatalf("st.readFrameData() = %x, %v; want %x, nil", gotData, err, frame.data)
+ }
+ if err := st2.endFrame(); err != nil {
+ t.Fatalf("st.endFrame() = %v; want nil", err)
+ }
+ }
+}
+
+func TestStreamReadFrameUnderflow(t *testing.T) {
+ const size = 4
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(0) // type
+ st1.writeVarint(size) // size
+ st1.Write(make([]byte, size)) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if _, err := io.ReadFull(st2, make([]byte, size-1)); err != nil {
+ t.Fatalf("st.Read() = %v", err)
+ }
+ // We have not consumed the full frame: Error.
+ if err := st2.endFrame(); !errors.Is(err, errH3FrameError) {
+ t.Fatalf("st.endFrame before end: %v, want errH3FrameError", err)
+ }
+}
+
+func TestStreamReadFrameWithoutEnd(t *testing.T) {
+ const size = 4
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(0) // type
+ st1.writeVarint(size) // size
+ st1.Write(make([]byte, size)) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if _, err := st2.readFrameHeader(); err == nil {
+ t.Fatalf("st.readFrameHeader before st.endFrame for prior frame: success, want error")
+ }
+}
+
+func TestStreamReadFrameOverflow(t *testing.T) {
+ const size = 4
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(0) // type
+ st1.writeVarint(size) // size
+ st1.Write(make([]byte, size+1)) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if _, err := io.ReadFull(st2, make([]byte, size+1)); !errors.Is(err, errH3FrameError) {
+ t.Fatalf("st.Read past end of frame: %v, want errH3FrameError", err)
+ }
+}
+
+func TestStreamReadFrameHeaderPartial(t *testing.T) {
+ var frame []byte
+ frame = quicwire.AppendVarint(frame, 1000) // type
+ frame = quicwire.AppendVarint(frame, 2000) // size
+
+ for i := 1; i < len(frame)-1; i++ {
+ st1, st2 := newStreamPair(t)
+ st1.Write(frame[:i])
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ st1.stream.CloseWrite()
+
+ if _, err := st2.readFrameHeader(); err == nil {
+			t.Fatalf("%v/%v bytes of frame available: st.readFrameHeader() succeeded; want error", i, len(frame))
+ }
+ }
+}
+
+func TestStreamReadFrameDataPartial(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(1) // type
+ st1.writeVarint(100) // size
+ st1.Write(make([]byte, 50)) // data
+ st1.stream.CloseWrite()
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if n, err := io.ReadAll(st2); err == nil {
+ t.Fatalf("io.ReadAll with partial frame = %v, nil; want error", n)
+ }
+}
+
+func TestStreamReadByteFrameDataPartial(t *testing.T) {
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(1) // type
+ st1.writeVarint(100) // size
+ st1.stream.CloseWrite()
+ if _, err := st2.readFrameHeader(); err != nil {
+ t.Fatalf("st.readFrameHeader() = %v", err)
+ }
+ if b, err := st2.ReadByte(); err == nil {
+		t.Fatalf("st.ReadByte with partial frame = %v, nil; want error", b)
+ }
+}
+
+func TestStreamReadFrameDataAtEOF(t *testing.T) {
+ const typ = 10
+ data := []byte("hello")
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(typ) // type
+ st1.writeVarint(int64(len(data))) // size
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+ if got, err := st2.readFrameHeader(); err != nil || got != typ {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, typ)
+ }
+
+ st1.Write(data) // data
+ st1.stream.CloseWrite() // end stream
+ got := make([]byte, len(data)+1)
+ if n, err := st2.Read(got); err != nil || n != len(data) || !bytes.Equal(got[:n], data) {
+ t.Fatalf("st.Read() = %v, %v (data=%x); want %v, nil (data=%x)", n, err, got[:n], len(data), data)
+ }
+}
+
+func TestStreamReadFrameData(t *testing.T) {
+ const typ = 10
+ data := []byte("hello")
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(typ) // type
+ st1.writeVarint(int64(len(data))) // size
+ st1.Write(data) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if got, err := st2.readFrameHeader(); err != nil || got != typ {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, typ)
+ }
+ if got, err := st2.readFrameData(); err != nil || !bytes.Equal(got, data) {
+ t.Fatalf("st.readFrameData() = %x, %v; want %x, nil", got, err, data)
+ }
+}
+
+func TestStreamReadByte(t *testing.T) {
+ const stype = 1
+ const want = 42
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(stype) // stream type
+ st1.writeVarint(1) // size
+ st1.Write([]byte{want}) // data
+ if err := st1.Flush(); err != nil {
+ t.Fatal(err)
+ }
+
+ if got, err := st2.readFrameHeader(); err != nil || got != stype {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, stype)
+ }
+ if got, err := st2.ReadByte(); err != nil || got != want {
+ t.Fatalf("st.ReadByte() = %v, %v; want %v, nil", got, err, want)
+ }
+ if got, err := st2.ReadByte(); err == nil {
+ t.Fatalf("reading past end of frame: st.ReadByte() = %v, %v; want error", got, err)
+ }
+}
+
+func TestStreamDiscardFrame(t *testing.T) {
+ const typ = 10
+ data := []byte("hello")
+ st1, st2 := newStreamPair(t)
+ st1.writeVarint(typ) // type
+ st1.writeVarint(int64(len(data))) // size
+ st1.Write(data) // data
+ st1.stream.CloseWrite()
+
+ if got, err := st2.readFrameHeader(); err != nil || got != typ {
+ t.Fatalf("st.readFrameHeader() = %v, %v; want %v, nil", got, err, typ)
+ }
+ if err := st2.discardFrame(); err != nil {
+ t.Fatalf("st.discardFrame() = %v", err)
+ }
+ if b, err := io.ReadAll(st2); err != nil || len(b) > 0 {
+ t.Fatalf("after discarding frame, read %x, %v; want EOF", b, err)
+ }
+}
+
+func newStreamPair(t testing.TB) (s1, s2 *stream) {
+ t.Helper()
+ q1, q2 := newQUICStreamPair(t)
+ return newStream(q1), newStream(q2)
+}
diff --git a/internal/http3/transport.go b/internal/http3/transport.go
new file mode 100644
index 0000000000..b26524cbda
--- /dev/null
+++ b/internal/http3/transport.go
@@ -0,0 +1,190 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24
+
+package http3
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ "golang.org/x/net/quic"
+)
+
+// A Transport is an HTTP/3 transport.
+//
+// It does not manage a pool of connections,
+// and therefore does not implement net/http.RoundTripper.
+//
+// TODO: Provide a way to register an HTTP/3 transport with a net/http.Transport's
+// connection pool.
+type Transport struct {
+ // Endpoint is the QUIC endpoint used by connections created by the transport.
+ // If unset, it is initialized by the first call to Dial.
+ Endpoint *quic.Endpoint
+
+ // Config is the QUIC configuration used for client connections.
+ // The Config may be nil.
+ //
+ // Dial may clone and modify the Config.
+ // The Config must not be modified after calling Dial.
+ Config *quic.Config
+
+ initOnce sync.Once
+ initErr error
+}
+
+func (tr *Transport) init() error {
+ tr.initOnce.Do(func() {
+ tr.Config = initConfig(tr.Config)
+ if tr.Endpoint == nil {
+ tr.Endpoint, tr.initErr = quic.Listen("udp", ":0", nil)
+ }
+ })
+ return tr.initErr
+}
+
+// Dial creates a new HTTP/3 client connection.
+func (tr *Transport) Dial(ctx context.Context, target string) (*ClientConn, error) {
+ if err := tr.init(); err != nil {
+ return nil, err
+ }
+ qconn, err := tr.Endpoint.Dial(ctx, "udp", target, tr.Config)
+ if err != nil {
+ return nil, err
+ }
+ return newClientConn(ctx, qconn)
+}
+
+// A ClientConn is a client HTTP/3 connection.
+//
+// Multiple goroutines may invoke methods on a ClientConn simultaneously.
+type ClientConn struct {
+ qconn *quic.Conn
+ genericConn
+
+ enc qpackEncoder
+ dec qpackDecoder
+}
+
+func newClientConn(ctx context.Context, qconn *quic.Conn) (*ClientConn, error) {
+ cc := &ClientConn{
+ qconn: qconn,
+ }
+ cc.enc.init()
+
+ // Create control stream and send SETTINGS frame.
+ controlStream, err := newConnStream(ctx, cc.qconn, streamTypeControl)
+ if err != nil {
+ return nil, fmt.Errorf("http3: cannot create control stream: %v", err)
+ }
+ controlStream.writeSettings()
+ controlStream.Flush()
+
+ go cc.acceptStreams(qconn, cc)
+ return cc, nil
+}
+
+// Close closes the connection.
+// Any in-flight requests are canceled.
+// Close does not wait for the peer to acknowledge the connection closing.
+func (cc *ClientConn) Close() error {
+ // Close the QUIC connection immediately with a status of NO_ERROR.
+ cc.qconn.Abort(nil)
+
+ // Return any existing error from the peer, but don't wait for it.
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ return cc.qconn.Wait(ctx)
+}
+
+func (cc *ClientConn) handleControlStream(st *stream) error {
+ // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2
+ if err := st.readSettings(func(settingsType, settingsValue int64) error {
+ switch settingsType {
+ case settingsMaxFieldSectionSize:
+ _ = settingsValue // TODO
+ case settingsQPACKMaxTableCapacity:
+ _ = settingsValue // TODO
+ case settingsQPACKBlockedStreams:
+ _ = settingsValue // TODO
+ default:
+ // Unknown settings types are ignored.
+ }
+ return nil
+ }); err != nil {
+ return err
+ }
+
+ for {
+ ftype, err := st.readFrameHeader()
+ if err != nil {
+ return err
+ }
+ switch ftype {
+ case frameTypeCancelPush:
+ // "If a CANCEL_PUSH frame is received that references a push ID
+ // greater than currently allowed on the connection,
+ // this MUST be treated as a connection error of type H3_ID_ERROR."
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-7
+ return &connectionError{
+ code: errH3IDError,
+ message: "CANCEL_PUSH received when no MAX_PUSH_ID has been sent",
+ }
+ case frameTypeGoaway:
+ // TODO: Wait for requests to complete before closing connection.
+ return errH3NoError
+ default:
+ // Unknown frames are ignored.
+ if err := st.discardUnknownFrame(ftype); err != nil {
+ return err
+ }
+ }
+ }
+}
+
+func (cc *ClientConn) handleEncoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (cc *ClientConn) handleDecoderStream(*stream) error {
+ // TODO
+ return nil
+}
+
+func (cc *ClientConn) handlePushStream(*stream) error {
+ // "A client MUST treat receipt of a push stream as a connection error
+ // of type H3_ID_ERROR when no MAX_PUSH_ID frame has been sent [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-4.6-3
+ return &connectionError{
+ code: errH3IDError,
+ message: "push stream created when no MAX_PUSH_ID has been sent",
+ }
+}
+
+func (cc *ClientConn) handleRequestStream(st *stream) error {
+ // "Clients MUST treat receipt of a server-initiated bidirectional
+ // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3
+ return &connectionError{
+ code: errH3StreamCreationError,
+ message: "server created bidirectional stream",
+ }
+}
+
+// abort closes the connection with an error.
+func (cc *ClientConn) abort(err error) {
+ if e, ok := err.(*connectionError); ok {
+ cc.qconn.Abort(&quic.ApplicationError{
+ Code: uint64(e.code),
+ Reason: e.message,
+ })
+ } else {
+ cc.qconn.Abort(err)
+ }
+}
diff --git a/internal/http3/transport_test.go b/internal/http3/transport_test.go
new file mode 100644
index 0000000000..b300866390
--- /dev/null
+++ b/internal/http3/transport_test.go
@@ -0,0 +1,448 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.24 && goexperiment.synctest
+
+package http3
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "reflect"
+ "slices"
+ "testing"
+ "testing/synctest"
+
+ "golang.org/x/net/internal/quic/quicwire"
+ "golang.org/x/net/quic"
+)
+
+func TestTransportServerCreatesBidirectionalStream(t *testing.T) {
+ // "Clients MUST treat receipt of a server-initiated bidirectional
+ // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]"
+ // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3
+ runSynctest(t, func(t testing.TB) {
+ tc := newTestClientConn(t)
+ tc.greet()
+ st := tc.newStream(streamTypeRequest)
+ st.Flush()
+ tc.wantClosed("after server creates bidi stream", errH3StreamCreationError)
+ })
+}
+
+// A testQUICConn wraps a *quic.Conn and provides methods for inspecting it.
+type testQUICConn struct {
+ t testing.TB
+ qconn *quic.Conn
+ streams map[streamType][]*testQUICStream
+}
+
+func newTestQUICConn(t testing.TB, qconn *quic.Conn) *testQUICConn {
+ tq := &testQUICConn{
+ t: t,
+ qconn: qconn,
+ streams: make(map[streamType][]*testQUICStream),
+ }
+
+ go tq.acceptStreams(t.Context())
+
+ t.Cleanup(func() {
+ tq.qconn.Close()
+ })
+ return tq
+}
+
+func (tq *testQUICConn) acceptStreams(ctx context.Context) {
+ for {
+ qst, err := tq.qconn.AcceptStream(ctx)
+ if err != nil {
+ return
+ }
+ st := newStream(qst)
+ stype := streamTypeRequest
+ if qst.IsReadOnly() {
+ v, err := st.readVarint()
+ if err != nil {
+ tq.t.Errorf("error reading stream type from unidirectional stream: %v", err)
+ continue
+ }
+ stype = streamType(v)
+ }
+ tq.streams[stype] = append(tq.streams[stype], newTestQUICStream(tq.t, st))
+ }
+}
+
+func (tq *testQUICConn) newStream(stype streamType) *testQUICStream {
+ tq.t.Helper()
+ var qs *quic.Stream
+ var err error
+ if stype == streamTypeRequest {
+ qs, err = tq.qconn.NewStream(canceledCtx)
+ } else {
+ qs, err = tq.qconn.NewSendOnlyStream(canceledCtx)
+ }
+ if err != nil {
+ tq.t.Fatal(err)
+ }
+ st := newStream(qs)
+ if stype != streamTypeRequest {
+ st.writeVarint(int64(stype))
+ if err := st.Flush(); err != nil {
+ tq.t.Fatal(err)
+ }
+ }
+ return newTestQUICStream(tq.t, st)
+}
+
+// wantNotClosed asserts that the peer has not closed the connection.
+func (tq *testQUICConn) wantNotClosed(reason string) {
+ t := tq.t
+ t.Helper()
+ synctest.Wait()
+ err := tq.qconn.Wait(canceledCtx)
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("%v: want QUIC connection to be alive; closed with error: %v", reason, err)
+ }
+}
+
+// wantClosed asserts that the peer has closed the connection
+// with the provided error code.
+func (tq *testQUICConn) wantClosed(reason string, want error) {
+ t := tq.t
+ t.Helper()
+ synctest.Wait()
+
+ if e, ok := want.(http3Error); ok {
+ want = &quic.ApplicationError{Code: uint64(e)}
+ }
+ got := tq.qconn.Wait(canceledCtx)
+ if errors.Is(got, context.Canceled) {
+ t.Fatalf("%v: want QUIC connection closed, but it is not", reason)
+ }
+ if !errors.Is(got, want) {
+ t.Fatalf("%v: connection closed with error: %v; want %v", reason, got, want)
+ }
+}
+
+// wantStream asserts that a stream of a given type has been created,
+// and returns that stream.
+func (tq *testQUICConn) wantStream(stype streamType) *testQUICStream {
+ tq.t.Helper()
+ synctest.Wait()
+ if len(tq.streams[stype]) == 0 {
+ tq.t.Fatalf("expected a %v stream to be created, but none were", stype)
+ }
+ ts := tq.streams[stype][0]
+ tq.streams[stype] = tq.streams[stype][1:]
+ return ts
+}
+
+// testQUICStream wraps a QUIC stream and provides methods for inspecting it.
+type testQUICStream struct {
+ t testing.TB
+ *stream
+}
+
+func newTestQUICStream(t testing.TB, st *stream) *testQUICStream {
+ st.stream.SetReadContext(canceledCtx)
+ st.stream.SetWriteContext(canceledCtx)
+ return &testQUICStream{
+ t: t,
+ stream: st,
+ }
+}
+
+// wantFrameHeader calls readFrameHeader and asserts that the frame is of a given type.
+func (ts *testQUICStream) wantFrameHeader(reason string, wantType frameType) {
+ ts.t.Helper()
+ synctest.Wait()
+ gotType, err := ts.readFrameHeader()
+ if err != nil {
+ ts.t.Fatalf("%v: failed to read frame header: %v", reason, err)
+ }
+ if gotType != wantType {
+ ts.t.Fatalf("%v: got frame type %v, want %v", reason, gotType, wantType)
+ }
+}
+
+// wantHeaders reads a HEADERS frame.
+// If want is nil, the contents of the frame are ignored.
+func (ts *testQUICStream) wantHeaders(want http.Header) {
+ ts.t.Helper()
+ ftype, err := ts.readFrameHeader()
+ if err != nil {
+ ts.t.Fatalf("want HEADERS frame, got error: %v", err)
+ }
+ if ftype != frameTypeHeaders {
+ ts.t.Fatalf("want HEADERS frame, got: %v", ftype)
+ }
+
+ if want == nil {
+ if err := ts.discardFrame(); err != nil {
+ ts.t.Fatalf("discardFrame: %v", err)
+ }
+ return
+ }
+
+ got := make(http.Header)
+ var dec qpackDecoder
+ err = dec.decode(ts.stream, func(_ indexType, name, value string) error {
+ got.Add(name, value)
+ return nil
+ })
+ if diff := diffHeaders(got, want); diff != "" {
+ ts.t.Fatalf("unexpected response headers:\n%v", diff)
+ }
+ if err := ts.endFrame(); err != nil {
+ ts.t.Fatalf("endFrame: %v", err)
+ }
+}
+
+func (ts *testQUICStream) encodeHeaders(h http.Header) []byte {
+ ts.t.Helper()
+ var enc qpackEncoder
+ return enc.encode(func(yield func(itype indexType, name, value string)) {
+ names := slices.Collect(maps.Keys(h))
+ slices.Sort(names)
+ for _, k := range names {
+ for _, v := range h[k] {
+ yield(mayIndex, k, v)
+ }
+ }
+ })
+}
+
+func (ts *testQUICStream) writeHeaders(h http.Header) {
+ ts.t.Helper()
+ headers := ts.encodeHeaders(h)
+ ts.writeVarint(int64(frameTypeHeaders))
+ ts.writeVarint(int64(len(headers)))
+ ts.Write(headers)
+ if err := ts.Flush(); err != nil {
+ ts.t.Fatalf("flushing HEADERS frame: %v", err)
+ }
+}
+
+func (ts *testQUICStream) wantData(want []byte) {
+ ts.t.Helper()
+ synctest.Wait()
+ ftype, err := ts.readFrameHeader()
+ if err != nil {
+ ts.t.Fatalf("want DATA frame, got error: %v", err)
+ }
+ if ftype != frameTypeData {
+ ts.t.Fatalf("want DATA frame, got: %v", ftype)
+ }
+ got, err := ts.readFrameData()
+ if err != nil {
+ ts.t.Fatalf("error reading DATA frame: %v", err)
+ }
+ if !bytes.Equal(got, want) {
+ ts.t.Fatalf("got data: {%x}, want {%x}", got, want)
+ }
+ if err := ts.endFrame(); err != nil {
+ ts.t.Fatalf("endFrame: %v", err)
+ }
+}
+
+func (ts *testQUICStream) wantClosed(reason string) {
+ ts.t.Helper()
+ synctest.Wait()
+ ftype, err := ts.readFrameHeader()
+ if err != io.EOF {
+ ts.t.Fatalf("%v: want io.EOF, got %v %v", reason, ftype, err)
+ }
+}
+
+func (ts *testQUICStream) wantError(want quic.StreamErrorCode) {
+ ts.t.Helper()
+ synctest.Wait()
+ _, err := ts.stream.stream.ReadByte()
+ if err == nil {
+ ts.t.Fatalf("successfully read from stream; want stream error code %v", want)
+ }
+ var got quic.StreamErrorCode
+ if !errors.As(err, &got) {
+ ts.t.Fatalf("stream error = %v; want %v", err, want)
+ }
+ if got != want {
+ ts.t.Fatalf("stream error code = %v; want %v", got, want)
+ }
+}
+
+func (ts *testQUICStream) writePushPromise(pushID int64, h http.Header) {
+ ts.t.Helper()
+ headers := ts.encodeHeaders(h)
+ ts.writeVarint(int64(frameTypePushPromise))
+ ts.writeVarint(int64(quicwire.SizeVarint(uint64(pushID)) + len(headers)))
+ ts.writeVarint(pushID)
+ ts.Write(headers)
+ if err := ts.Flush(); err != nil {
+ ts.t.Fatalf("flushing PUSH_PROMISE frame: %v", err)
+ }
+}
+
+func diffHeaders(got, want http.Header) string {
+ // nil and 0-length non-nil are equal.
+ if len(got) == 0 && len(want) == 0 {
+ return ""
+ }
+ // We could do a more sophisticated diff here.
+ // DeepEqual is good enough for now.
+ if reflect.DeepEqual(got, want) {
+ return ""
+ }
+ return fmt.Sprintf("got: %v\nwant: %v", got, want)
+}
+
+func (ts *testQUICStream) Flush() error {
+ err := ts.stream.Flush()
+ ts.t.Helper()
+ if err != nil {
+ ts.t.Errorf("unexpected error flushing stream: %v", err)
+ }
+ return err
+}
+
+// A testClientConn is a ClientConn on a test network.
+type testClientConn struct {
+ tr *Transport
+ cc *ClientConn
+
+ // *testQUICConn is the server half of the connection.
+ *testQUICConn
+ control *testQUICStream
+}
+
+func newTestClientConn(t testing.TB) *testClientConn {
+ e1, e2 := newQUICEndpointPair(t)
+ tr := &Transport{
+ Endpoint: e1,
+ Config: &quic.Config{
+ TLSConfig: testTLSConfig,
+ },
+ }
+
+ cc, err := tr.Dial(t.Context(), e2.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() {
+ cc.Close()
+ })
+ srvConn, err := e2.Accept(t.Context())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tc := &testClientConn{
+ tr: tr,
+ cc: cc,
+ testQUICConn: newTestQUICConn(t, srvConn),
+ }
+ synctest.Wait()
+ return tc
+}
+
+// greet performs initial connection handshaking with the client.
+func (tc *testClientConn) greet() {
+ // Client creates a control stream.
+ clientControlStream := tc.wantStream(streamTypeControl)
+ clientControlStream.wantFrameHeader(
+ "client sends SETTINGS frame on control stream",
+ frameTypeSettings)
+ clientControlStream.discardFrame()
+
+ // Server creates a control stream.
+ tc.control = tc.newStream(streamTypeControl)
+ tc.control.writeVarint(int64(frameTypeSettings))
+ tc.control.writeVarint(0) // size
+ tc.control.Flush()
+
+ synctest.Wait()
+}
+
+type testRoundTrip struct {
+ t testing.TB
+ resp *http.Response
+ respErr error
+}
+
+func (rt *testRoundTrip) done() bool {
+ synctest.Wait()
+ return rt.resp != nil || rt.respErr != nil
+}
+
+func (rt *testRoundTrip) result() (*http.Response, error) {
+ rt.t.Helper()
+ if !rt.done() {
+ rt.t.Fatal("RoundTrip is not done; want it to be")
+ }
+ return rt.resp, rt.respErr
+}
+
+func (rt *testRoundTrip) response() *http.Response {
+ rt.t.Helper()
+ if !rt.done() {
+ rt.t.Fatal("RoundTrip is not done; want it to be")
+ }
+ if rt.respErr != nil {
+ rt.t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
+ }
+ return rt.resp
+}
+
+// err returns the (possibly nil) error result of RoundTrip.
+func (rt *testRoundTrip) err() error {
+ rt.t.Helper()
+ _, err := rt.result()
+ return err
+}
+
+func (rt *testRoundTrip) wantError(reason string) {
+ rt.t.Helper()
+ synctest.Wait()
+ if !rt.done() {
+ rt.t.Fatalf("%v: RoundTrip is not done; want it to have returned an error", reason)
+ }
+ if rt.respErr == nil {
+ rt.t.Fatalf("%v: RoundTrip succeeded; want it to have returned an error", reason)
+ }
+}
+
+// wantStatus indicates the expected response StatusCode.
+func (rt *testRoundTrip) wantStatus(want int) {
+ rt.t.Helper()
+ if got := rt.response().StatusCode; got != want {
+ rt.t.Fatalf("got response status %v, want %v", got, want)
+ }
+}
+
+func (rt *testRoundTrip) wantHeaders(want http.Header) {
+ rt.t.Helper()
+ if diff := diffHeaders(rt.response().Header, want); diff != "" {
+ rt.t.Fatalf("unexpected response headers:\n%v", diff)
+ }
+}
+
+func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
+ rt := &testRoundTrip{t: tc.t}
+ go func() {
+ rt.resp, rt.respErr = tc.cc.RoundTrip(req)
+ }()
+ return rt
+}
+
+// canceledCtx is a canceled Context.
+// Used for performing non-blocking QUIC operations.
+var canceledCtx = func() context.Context {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ return ctx
+}()
diff --git a/internal/httpcommon/ascii.go b/internal/httpcommon/ascii.go
new file mode 100644
index 0000000000..ed14da5afc
--- /dev/null
+++ b/internal/httpcommon/ascii.go
@@ -0,0 +1,53 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon
+
+import "strings"
+
+// The HTTP protocols are defined in terms of ASCII, not Unicode. This file
+// contains helper functions which may use Unicode-aware functions which would
+// otherwise be unsafe and could introduce vulnerabilities if used improperly.
+
+// asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// are equal, ASCII-case-insensitively.
+func asciiEqualFold(s, t string) bool {
+ if len(s) != len(t) {
+ return false
+ }
+ for i := 0; i < len(s); i++ {
+ if lower(s[i]) != lower(t[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+// lower returns the ASCII lowercase version of b.
+func lower(b byte) byte {
+ if 'A' <= b && b <= 'Z' {
+ return b + ('a' - 'A')
+ }
+ return b
+}
+
+// isASCIIPrint returns whether s is ASCII and printable according to
+// https://tools.ietf.org/html/rfc20#section-4.2.
+func isASCIIPrint(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] < ' ' || s[i] > '~' {
+ return false
+ }
+ }
+ return true
+}
+
+// asciiToLower returns the lowercase version of s if s is ASCII and printable,
+// and whether or not it was.
+func asciiToLower(s string) (lower string, ok bool) {
+ if !isASCIIPrint(s) {
+ return "", false
+ }
+ return strings.ToLower(s), true
+}
diff --git a/http2/headermap.go b/internal/httpcommon/headermap.go
similarity index 74%
rename from http2/headermap.go
rename to internal/httpcommon/headermap.go
index 149b3dd20e..92483d8e41 100644
--- a/http2/headermap.go
+++ b/internal/httpcommon/headermap.go
@@ -1,11 +1,11 @@
-// Copyright 2014 The Go Authors. All rights reserved.
+// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-package http2
+package httpcommon
import (
- "net/http"
+ "net/textproto"
"sync"
)
@@ -82,13 +82,15 @@ func buildCommonHeaderMaps() {
commonLowerHeader = make(map[string]string, len(common))
commonCanonHeader = make(map[string]string, len(common))
for _, v := range common {
- chk := http.CanonicalHeaderKey(v)
+ chk := textproto.CanonicalMIMEHeaderKey(v)
commonLowerHeader[chk] = v
commonCanonHeader[v] = chk
}
}
-func lowerHeader(v string) (lower string, ascii bool) {
+// LowerHeader returns the lowercase form of a header name,
+// used on the wire for HTTP/2 and HTTP/3 requests.
+func LowerHeader(v string) (lower string, ascii bool) {
buildCommonHeaderMapsOnce()
if s, ok := commonLowerHeader[v]; ok {
return s, true
@@ -96,10 +98,18 @@ func lowerHeader(v string) (lower string, ascii bool) {
return asciiToLower(v)
}
-func canonicalHeader(v string) string {
+// CanonicalHeader canonicalizes a header name. (For example, "host" becomes "Host".)
+func CanonicalHeader(v string) string {
buildCommonHeaderMapsOnce()
if s, ok := commonCanonHeader[v]; ok {
return s
}
- return http.CanonicalHeaderKey(v)
+ return textproto.CanonicalMIMEHeaderKey(v)
+}
+
+// CachedCanonicalHeader returns the canonical form of a well-known header name.
+func CachedCanonicalHeader(v string) (string, bool) {
+ buildCommonHeaderMapsOnce()
+ s, ok := commonCanonHeader[v]
+ return s, ok
}
diff --git a/internal/httpcommon/httpcommon_test.go b/internal/httpcommon/httpcommon_test.go
new file mode 100644
index 0000000000..e725ec76cb
--- /dev/null
+++ b/internal/httpcommon/httpcommon_test.go
@@ -0,0 +1,37 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon_test
+
+import (
+ "bytes"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+// This package is imported by the net/http package,
+// and therefore must not itself import net/http.
+func TestNoNetHttp(t *testing.T) {
+ files, err := filepath.Glob("*.go")
+ if err != nil {
+ t.Fatal(err)
+ }
+ for _, file := range files {
+ if strings.HasSuffix(file, "_test.go") {
+ continue
+ }
+ // Could use something complex like go/build or x/tools/go/packages,
+ // but there's no reason for "net/http" to appear (in quotes) in the source
+ // otherwise, so just use a simple substring search.
+ data, err := os.ReadFile(file)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if bytes.Contains(data, []byte(`"net/http"`)) {
+ t.Errorf(`%s: cannot import "net/http"`, file)
+ }
+ }
+}
diff --git a/internal/httpcommon/request.go b/internal/httpcommon/request.go
new file mode 100644
index 0000000000..4b70553179
--- /dev/null
+++ b/internal/httpcommon/request.go
@@ -0,0 +1,467 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http/httptrace"
+ "net/textproto"
+ "net/url"
+ "sort"
+ "strconv"
+ "strings"
+
+ "golang.org/x/net/http/httpguts"
+ "golang.org/x/net/http2/hpack"
+)
+
+var (
+ ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit")
+)
+
+// Request is a subset of http.Request.
+// It'd be simpler to pass an *http.Request, of course, but we can't depend on net/http
+// without creating a dependency cycle.
+type Request struct {
+ URL *url.URL
+ Method string
+ Host string
+ Header map[string][]string
+ Trailer map[string][]string
+ ActualContentLength int64 // 0 means 0, -1 means unknown
+}
+
+// EncodeHeadersParam is parameters to EncodeHeaders.
+type EncodeHeadersParam struct {
+ Request Request
+
+ // AddGzipHeader indicates that an "accept-encoding: gzip" header should be
+ // added to the request.
+ AddGzipHeader bool
+
+ // PeerMaxHeaderListSize, when non-zero, is the peer's MAX_HEADER_LIST_SIZE setting.
+ PeerMaxHeaderListSize uint64
+
+ // DefaultUserAgent is the User-Agent header to send when the request
+ // neither contains a User-Agent nor disables it.
+ DefaultUserAgent string
+}
+
+// EncodeHeadersResult is the result of EncodeHeaders.
+type EncodeHeadersResult struct {
+ HasBody bool
+ HasTrailers bool
+}
+
+// EncodeHeaders constructs request headers common to HTTP/2 and HTTP/3.
+// It validates a request and calls headerf with each pseudo-header and header
+// for the request.
+// The headerf function is called with the validated, canonicalized header name.
+func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) {
+ req := param.Request
+
+ // Check for invalid connection-level headers.
+ if err := checkConnHeaders(req.Header); err != nil {
+ return res, err
+ }
+
+ if req.URL == nil {
+ return res, errors.New("Request.URL is nil")
+ }
+
+ host := req.Host
+ if host == "" {
+ host = req.URL.Host
+ }
+ host, err := httpguts.PunycodeHostPort(host)
+ if err != nil {
+ return res, err
+ }
+ if !httpguts.ValidHostHeader(host) {
+ return res, errors.New("invalid Host header")
+ }
+
+ // isNormalConnect is true if this is a non-extended CONNECT request.
+ isNormalConnect := false
+ var protocol string
+ if vv := req.Header[":protocol"]; len(vv) > 0 {
+ protocol = vv[0]
+ }
+ if req.Method == "CONNECT" && protocol == "" {
+ isNormalConnect = true
+ } else if protocol != "" && req.Method != "CONNECT" {
+ return res, errors.New("invalid :protocol header in non-CONNECT request")
+ }
+
+ // Validate the path, except for non-extended CONNECT requests which have no path.
+ var path string
+ if !isNormalConnect {
+ path = req.URL.RequestURI()
+ if !validPseudoPath(path) {
+ orig := path
+ path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
+ if !validPseudoPath(path) {
+ if req.URL.Opaque != "" {
+ return res, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
+ } else {
+ return res, fmt.Errorf("invalid request :path %q", orig)
+ }
+ }
+ }
+ }
+
+ // Check for any invalid headers+trailers and return an error before we
+ // potentially pollute our hpack state. (We want to be able to
+ // continue to reuse the hpack encoder for future requests)
+ if err := validateHeaders(req.Header); err != "" {
+ return res, fmt.Errorf("invalid HTTP header %s", err)
+ }
+ if err := validateHeaders(req.Trailer); err != "" {
+ return res, fmt.Errorf("invalid HTTP trailer %s", err)
+ }
+
+ trailers, err := commaSeparatedTrailers(req.Trailer)
+ if err != nil {
+ return res, err
+ }
+
+ enumerateHeaders := func(f func(name, value string)) {
+ // 8.1.2.3 Request Pseudo-Header Fields
+ // The :path pseudo-header field includes the path and query parts of the
+ // target URI (the path-absolute production and optionally a '?' character
+ // followed by the query production, see Sections 3.3 and 3.4 of
+ // [RFC3986]).
+ f(":authority", host)
+ m := req.Method
+ if m == "" {
+ m = "GET"
+ }
+ f(":method", m)
+ if !isNormalConnect {
+ f(":path", path)
+ f(":scheme", req.URL.Scheme)
+ }
+ if protocol != "" {
+ f(":protocol", protocol)
+ }
+ if trailers != "" {
+ f("trailer", trailers)
+ }
+
+ var didUA bool
+ for k, vv := range req.Header {
+ if asciiEqualFold(k, "host") || asciiEqualFold(k, "content-length") {
+ // Host is :authority, already sent.
+ // Content-Length is automatic, set below.
+ continue
+ } else if asciiEqualFold(k, "connection") ||
+ asciiEqualFold(k, "proxy-connection") ||
+ asciiEqualFold(k, "transfer-encoding") ||
+ asciiEqualFold(k, "upgrade") ||
+ asciiEqualFold(k, "keep-alive") {
+ // Per 8.1.2.2 Connection-Specific Header
+ // Fields, don't send connection-specific
+ // fields. We have already checked if any
+ // are error-worthy so just ignore the rest.
+ continue
+ } else if asciiEqualFold(k, "user-agent") {
+ // Match Go's http1 behavior: at most one
+ // User-Agent. If set to nil or empty string,
+ // then omit it. Otherwise if not mentioned,
+ // include the default (below).
+ didUA = true
+ if len(vv) < 1 {
+ continue
+ }
+ vv = vv[:1]
+ if vv[0] == "" {
+ continue
+ }
+ } else if asciiEqualFold(k, "cookie") {
+ // Per 8.1.2.5 To allow for better compression efficiency, the
+ // Cookie header field MAY be split into separate header fields,
+ // each with one or more cookie-pairs.
+ for _, v := range vv {
+ for {
+ p := strings.IndexByte(v, ';')
+ if p < 0 {
+ break
+ }
+ f("cookie", v[:p])
+ p++
+ // strip space after semicolon if any.
+ for p+1 <= len(v) && v[p] == ' ' {
+ p++
+ }
+ v = v[p:]
+ }
+ if len(v) > 0 {
+ f("cookie", v)
+ }
+ }
+ continue
+ } else if k == ":protocol" {
+ // :protocol pseudo-header was already sent above.
+ continue
+ }
+
+ for _, v := range vv {
+ f(k, v)
+ }
+ }
+ if shouldSendReqContentLength(req.Method, req.ActualContentLength) {
+ f("content-length", strconv.FormatInt(req.ActualContentLength, 10))
+ }
+ if param.AddGzipHeader {
+ f("accept-encoding", "gzip")
+ }
+ if !didUA {
+ f("user-agent", param.DefaultUserAgent)
+ }
+ }
+
+ // Do a first pass over the headers counting bytes to ensure
+ // we don't exceed cc.peerMaxHeaderListSize. This is done as a
+ // separate pass before encoding the headers to prevent
+ // modifying the hpack state.
+ if param.PeerMaxHeaderListSize > 0 {
+ hlSize := uint64(0)
+ enumerateHeaders(func(name, value string) {
+ hf := hpack.HeaderField{Name: name, Value: value}
+ hlSize += uint64(hf.Size())
+ })
+
+ if hlSize > param.PeerMaxHeaderListSize {
+ return res, ErrRequestHeaderListSize
+ }
+ }
+
+ trace := httptrace.ContextClientTrace(ctx)
+
+ // Header list size is ok. Write the headers.
+ enumerateHeaders(func(name, value string) {
+ name, ascii := LowerHeader(name)
+ if !ascii {
+ // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
+ // field names have to be ASCII characters (just as in HTTP/1.x).
+ return
+ }
+
+ headerf(name, value)
+
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(name, []string{value})
+ }
+ })
+
+ res.HasBody = req.ActualContentLength != 0
+ res.HasTrailers = trailers != ""
+ return res, nil
+}
+
+// IsRequestGzip reports whether we should add an Accept-Encoding: gzip header
+// for a request.
+func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool {
+ // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
+ if !disableCompression &&
+ len(header["Accept-Encoding"]) == 0 &&
+ len(header["Range"]) == 0 &&
+ method != "HEAD" {
+ // Request gzip only, not deflate. Deflate is ambiguous and
+ // not as universally supported anyway.
+ // See: https://zlib.net/zlib_faq.html#faq39
+ //
+ // Note that we don't request this for HEAD requests,
+ // due to a bug in nginx:
+ // http://trac.nginx.org/nginx/ticket/358
+ // https://golang.org/issue/5522
+ //
+ // We don't request gzip if the request is for a range, since
+ // auto-decoding a portion of a gzipped document will just fail
+ // anyway. See https://golang.org/issue/8923
+ return true
+ }
+ return false
+}
+
+// checkConnHeaders checks whether req has any invalid connection-level headers.
+//
+// https://www.rfc-editor.org/rfc/rfc9114.html#section-4.2-3
+// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.2.2-1
+//
+// Certain headers are special-cased as okay but not transmitted later.
+// For example, we allow "Transfer-Encoding: chunked", but drop the header when encoding.
+func checkConnHeaders(h map[string][]string) error {
+ if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") {
+ return fmt.Errorf("invalid Upgrade request header: %q", vv)
+ }
+ if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
+ return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv)
+ }
+ if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) {
+ return fmt.Errorf("invalid Connection request header: %q", vv)
+ }
+ return nil
+}
+
+func commaSeparatedTrailers(trailer map[string][]string) (string, error) {
+ keys := make([]string, 0, len(trailer))
+ for k := range trailer {
+ k = CanonicalHeader(k)
+ switch k {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return "", fmt.Errorf("invalid Trailer key %q", k)
+ }
+ keys = append(keys, k)
+ }
+ if len(keys) > 0 {
+ sort.Strings(keys)
+ return strings.Join(keys, ","), nil
+ }
+ return "", nil
+}
+
+// validPseudoPath reports whether v is a valid :path pseudo-header
+// value. It must be either:
+//
+// - a non-empty string starting with '/'
+// - the string '*', for OPTIONS requests.
+//
+// For now this is only used as a quick check for deciding when to clean
+// up Opaque URLs before sending requests from the Transport.
+// See golang.org/issue/16847
+//
+// We used to enforce that the path also didn't start with "//", but
+// Google's GFE accepts such paths and Chrome sends them, so ignore
+// that part of the spec. See golang.org/issue/19103.
+func validPseudoPath(v string) bool {
+ return (len(v) > 0 && v[0] == '/') || v == "*"
+}
+
+func validateHeaders(hdrs map[string][]string) string {
+ for k, vv := range hdrs {
+ if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" {
+ return fmt.Sprintf("name %q", k)
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ // Don't include the value in the error,
+ // because it may be sensitive.
+ return fmt.Sprintf("value for header %q", k)
+ }
+ }
+ }
+ return ""
+}
+
+// shouldSendReqContentLength reports whether we should send
+// a "content-length" request header. This logic is basically a copy of the net/http
+// transferWriter.shouldSendContentLength.
+// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
+// -1 means unknown.
+func shouldSendReqContentLength(method string, contentLength int64) bool {
+ if contentLength > 0 {
+ return true
+ }
+ if contentLength < 0 {
+ return false
+ }
+ // For zero bodies, whether we send a content-length depends on the method.
+ // It also kinda doesn't matter for http2 either way, with END_STREAM.
+ switch method {
+ case "POST", "PUT", "PATCH":
+ return true
+ default:
+ return false
+ }
+}
+
+// ServerRequestParam is parameters to NewServerRequest.
+type ServerRequestParam struct {
+ Method string
+ Scheme, Authority, Path string
+ Protocol string
+ Header map[string][]string
+}
+
+// ServerRequestResult is the result of NewServerRequest.
+type ServerRequestResult struct {
+ // Various http.Request fields.
+ URL *url.URL
+ RequestURI string
+ Trailer map[string][]string
+
+ NeedsContinue bool // client provided an "Expect: 100-continue" header
+
+ // If the request should be rejected, this is a short string suitable for passing
+ // to the http2 package's CountError function.
+ // It might be a bit odd to return errors this way rather than returning an error,
+ // but this ensures we don't forget to include a CountError reason.
+ InvalidReason string
+}
+
+func NewServerRequest(rp ServerRequestParam) ServerRequestResult {
+ needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue")
+ if needsContinue {
+ delete(rp.Header, "Expect")
+ }
+ // Merge Cookie headers into one "; "-delimited value.
+ if cookies := rp.Header["Cookie"]; len(cookies) > 1 {
+ rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")}
+ }
+
+ // Setup Trailers
+ var trailer map[string][]string
+ for _, v := range rp.Header["Trailer"] {
+ for _, key := range strings.Split(v, ",") {
+ key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key))
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ // Bogus. (copy of http1 rules)
+ // Ignore.
+ default:
+ if trailer == nil {
+ trailer = make(map[string][]string)
+ }
+ trailer[key] = nil
+ }
+ }
+ }
+ delete(rp.Header, "Trailer")
+
+ // "':authority' MUST NOT include the deprecated userinfo subcomponent
+ // for "http" or "https" schemed URIs."
+ // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8
+ if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") {
+ return ServerRequestResult{
+ InvalidReason: "userinfo_in_authority",
+ }
+ }
+
+ var url_ *url.URL
+ var requestURI string
+ if rp.Method == "CONNECT" && rp.Protocol == "" {
+ url_ = &url.URL{Host: rp.Authority}
+ requestURI = rp.Authority // mimic HTTP/1 server behavior
+ } else {
+ var err error
+ url_, err = url.ParseRequestURI(rp.Path)
+ if err != nil {
+ return ServerRequestResult{
+ InvalidReason: "bad_path",
+ }
+ }
+ requestURI = rp.Path
+ }
+
+ return ServerRequestResult{
+ URL: url_,
+ NeedsContinue: needsContinue,
+ RequestURI: requestURI,
+ Trailer: trailer,
+ }
+}
diff --git a/internal/httpcommon/request_test.go b/internal/httpcommon/request_test.go
new file mode 100644
index 0000000000..b8792977c1
--- /dev/null
+++ b/internal/httpcommon/request_test.go
@@ -0,0 +1,672 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httpcommon
+
+import (
+ "cmp"
+ "context"
+ "io"
+ "net/http"
+ "slices"
+ "strings"
+ "testing"
+)
+
+func TestEncodeHeaders(t *testing.T) {
+ type header struct {
+ name string
+ value string
+ }
+ for _, test := range []struct {
+ name string
+ in EncodeHeadersParam
+ want EncodeHeadersResult
+ wantHeaders []header
+ disableCompression bool
+ }{{
+ name: "simple request",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("GET", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "host set from URL",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Host = ""
+ req.URL.Host = "example.tld"
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "chunked transfer-encoding",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Transfer-Encoding", "chunked") // ignored
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "connection close",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Connection", "close")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "connection keep-alive",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Connection", "keep-alive")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "normal connect",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("CONNECT", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "CONNECT"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "extended connect",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("CONNECT", "https://example.tld/", nil))
+ req.Header.Set(":protocol", "foo")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "CONNECT"},
+ {":path", "/"},
+ {":protocol", "foo"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "trailers",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("a", "1")
+ req.Trailer.Set("b", "2")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: true,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"trailer", "A,B"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "override user-agent",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("User-Agent", "GopherTron 9000")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "GopherTron 9000"},
+ },
+ }, {
+ name: "disable user-agent",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header["User-Agent"] = nil
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ },
+ }, {
+ name: "ignore host header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Host", "gophers.tld/") // ignored
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "crumble cookie header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Cookie", "a=b; b=c; c=d")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ // Cookie header is split into separate header fields.
+ {"cookie", "a=b"},
+ {"cookie", "b=c"},
+ {"cookie", "c=d"},
+ },
+ }, {
+ name: "post with nil body",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("POST", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ {"content-length", "0"},
+ },
+ }, {
+ name: "post with NoBody",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("POST", "https://example.tld/", http.NoBody))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ {"content-length", "0"},
+ },
+ }, {
+ name: "post with Content-Length",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ type reader struct{ io.ReadCloser }
+ req := must(http.NewRequest("POST", "https://example.tld/", reader{}))
+ req.ContentLength = 10
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: true,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ {"content-length", "10"},
+ },
+ }, {
+ name: "post with unknown Content-Length",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ type reader struct{ io.ReadCloser }
+ req := must(http.NewRequest("POST", "https://example.tld/", reader{}))
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: true,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "POST"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "gzip"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "explicit accept-encoding",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Accept-Encoding", "deflate")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "GET"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"accept-encoding", "deflate"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "head request",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ return must(http.NewRequest("HEAD", "https://example.tld/", nil))
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "HEAD"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"user-agent", "default-user-agent"},
+ },
+ }, {
+ name: "range request",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("HEAD", "https://example.tld/", nil))
+ req.Header.Set("Range", "bytes=0-10")
+ return req
+ }),
+ DefaultUserAgent: "default-user-agent",
+ },
+ want: EncodeHeadersResult{
+ HasBody: false,
+ HasTrailers: false,
+ },
+ wantHeaders: []header{
+ {":authority", "example.tld"},
+ {":method", "HEAD"},
+ {":path", "/"},
+ {":scheme", "https"},
+ {"user-agent", "default-user-agent"},
+ {"range", "bytes=0-10"},
+ },
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ var gotHeaders []header
+ if IsRequestGzip(test.in.Request.Method, test.in.Request.Header, test.disableCompression) {
+ test.in.AddGzipHeader = true
+ }
+
+ got, err := EncodeHeaders(context.Background(), test.in, func(name, value string) {
+ gotHeaders = append(gotHeaders, header{name, value})
+ })
+ if err != nil {
+ t.Fatalf("EncodeHeaders = %v", err)
+ }
+ if got.HasBody != test.want.HasBody {
+ t.Errorf("HasBody = %v, want %v", got.HasBody, test.want.HasBody)
+ }
+ if got.HasTrailers != test.want.HasTrailers {
+ t.Errorf("HasTrailers = %v, want %v", got.HasTrailers, test.want.HasTrailers)
+ }
+ cmpHeader := func(a, b header) int {
+ return cmp.Or(
+ cmp.Compare(a.name, b.name),
+ cmp.Compare(a.value, b.value),
+ )
+ }
+ slices.SortFunc(gotHeaders, cmpHeader)
+ slices.SortFunc(test.wantHeaders, cmpHeader)
+ if !slices.Equal(gotHeaders, test.wantHeaders) {
+ t.Errorf("got headers:")
+ for _, h := range gotHeaders {
+ t.Errorf(" %v: %q", h.name, h.value)
+ }
+ t.Errorf("want headers:")
+ for _, h := range test.wantHeaders {
+ t.Errorf(" %v: %q", h.name, h.value)
+ }
+ }
+ })
+ }
+}
+
+func TestEncodeHeaderErrors(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ in EncodeHeadersParam
+ want string
+ }{{
+ name: "URL is nil",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.URL = nil
+ return req
+ }),
+ },
+ want: "URL is nil",
+ }, {
+ name: "upgrade header is set",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Upgrade", "foo")
+ return req
+ }),
+ },
+ want: "Upgrade",
+ }, {
+ name: "unsupported transfer-encoding header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Transfer-Encoding", "identity")
+ return req
+ }),
+ },
+ want: "Transfer-Encoding",
+ }, {
+ name: "unsupported connection header",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("Connection", "x")
+ return req
+ }),
+ },
+ want: "Connection",
+ }, {
+ name: "invalid host",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Host = "\x00.tld"
+ return req
+ }),
+ },
+ want: "Host",
+ }, {
+ name: "protocol header is set",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set(":protocol", "foo")
+ return req
+ }),
+ },
+ want: ":protocol",
+ }, {
+ name: "invalid path",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.URL.Path = "no_leading_slash"
+ return req
+ }),
+ },
+ want: "path",
+ }, {
+ name: "invalid header name",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("x\ny", "foo")
+ return req
+ }),
+ },
+ want: "header",
+ }, {
+ name: "invalid header value",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("x", "foo\nbar")
+ return req
+ }),
+ },
+ want: "header",
+ }, {
+ name: "invalid trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("x\ny", "foo")
+ return req
+ }),
+ },
+ want: "trailer",
+ }, {
+ name: "transfer-encoding trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("Transfer-Encoding", "chunked")
+ return req
+ }),
+ },
+ want: "Trailer",
+ }, {
+ name: "trailer trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("Trailer", "chunked")
+ return req
+ }),
+ },
+ want: "Trailer",
+ }, {
+ name: "content-length trailer",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Trailer = make(http.Header)
+ req.Trailer.Set("Content-Length", "0")
+ return req
+ }),
+ },
+ want: "Trailer",
+ }, {
+ name: "too many headers",
+ in: EncodeHeadersParam{
+ Request: newReq(func() *http.Request {
+ req := must(http.NewRequest("GET", "https://example.tld/", nil))
+ req.Header.Set("X-Foo", strings.Repeat("x", 1000))
+ return req
+ }),
+ PeerMaxHeaderListSize: 1000,
+ },
+ want: "limit",
+ }} {
+ t.Run(test.name, func(t *testing.T) {
+ _, err := EncodeHeaders(context.Background(), test.in, func(name, value string) {})
+ if err == nil {
+ t.Fatalf("EncodeHeaders = nil, want %q", test.want)
+ }
+ if !strings.Contains(err.Error(), test.want) {
+ t.Fatalf("EncodeHeaders = %q, want error containing %q", err, test.want)
+ }
+ })
+ }
+}
+
+func newReq(f func() *http.Request) Request {
+ req := f()
+ contentLength := req.ContentLength
+ if req.Body == nil || req.Body == http.NoBody {
+ contentLength = 0
+ } else if contentLength == 0 {
+ contentLength = -1
+ }
+ return Request{
+ Header: req.Header,
+ Trailer: req.Trailer,
+ URL: req.URL,
+ Host: req.Host,
+ Method: req.Method,
+ ActualContentLength: contentLength,
+ }
+}
+
+func must[T any](v T, err error) T {
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
diff --git a/quic/wire.go b/internal/quic/quicwire/wire.go
similarity index 65%
rename from quic/wire.go
rename to internal/quic/quicwire/wire.go
index 8486029151..0edf42227d 100644
--- a/quic/wire.go
+++ b/internal/quic/quicwire/wire.go
@@ -4,20 +4,22 @@
//go:build go1.21
-package quic
+// Package quicwire encodes and decodes QUIC/HTTP3 wire encoding types,
+// particularly variable-length integers.
+package quicwire
import "encoding/binary"
const (
- maxVarintSize = 8 // encoded size in bytes
- maxVarint = (1 << 62) - 1
+ MaxVarintSize = 8 // encoded size in bytes
+ MaxVarint = (1 << 62) - 1
)
-// consumeVarint parses a variable-length integer, reporting its length.
+// ConsumeVarint parses a variable-length integer, reporting its length.
// It returns a negative length upon an error.
//
// https://www.rfc-editor.org/rfc/rfc9000.html#section-16
-func consumeVarint(b []byte) (v uint64, n int) {
+func ConsumeVarint(b []byte) (v uint64, n int) {
if len(b) < 1 {
return 0, -1
}
@@ -44,17 +46,17 @@ func consumeVarint(b []byte) (v uint64, n int) {
return 0, -1
}
-// consumeVarint64 parses a variable-length integer as an int64.
-func consumeVarintInt64(b []byte) (v int64, n int) {
- u, n := consumeVarint(b)
+// consumeVarintInt64 parses a variable-length integer as an int64.
+func ConsumeVarintInt64(b []byte) (v int64, n int) {
+ u, n := ConsumeVarint(b)
// QUIC varints are 62-bits large, so this conversion can never overflow.
return int64(u), n
}
-// appendVarint appends a variable-length integer to b.
+// AppendVarint appends a variable-length integer to b.
//
// https://www.rfc-editor.org/rfc/rfc9000.html#section-16
-func appendVarint(b []byte, v uint64) []byte {
+func AppendVarint(b []byte, v uint64) []byte {
switch {
case v <= 63:
return append(b, byte(v))
@@ -69,8 +71,8 @@ func appendVarint(b []byte, v uint64) []byte {
}
}
-// sizeVarint returns the size of the variable-length integer encoding of f.
-func sizeVarint(v uint64) int {
+// SizeVarint returns the size of the variable-length integer encoding of v.
+func SizeVarint(v uint64) int {
switch {
case v <= 63:
return 1
@@ -85,28 +87,28 @@ func sizeVarint(v uint64) int {
}
}
-// consumeUint32 parses a 32-bit fixed-length, big-endian integer, reporting its length.
+// ConsumeUint32 parses a 32-bit fixed-length, big-endian integer, reporting its length.
// It returns a negative length upon an error.
-func consumeUint32(b []byte) (uint32, int) {
+func ConsumeUint32(b []byte) (uint32, int) {
if len(b) < 4 {
return 0, -1
}
return binary.BigEndian.Uint32(b), 4
}
-// consumeUint64 parses a 64-bit fixed-length, big-endian integer, reporting its length.
+// ConsumeUint64 parses a 64-bit fixed-length, big-endian integer, reporting its length.
// It returns a negative length upon an error.
-func consumeUint64(b []byte) (uint64, int) {
+func ConsumeUint64(b []byte) (uint64, int) {
if len(b) < 8 {
return 0, -1
}
return binary.BigEndian.Uint64(b), 8
}
-// consumeUint8Bytes parses a sequence of bytes prefixed with an 8-bit length,
+// ConsumeUint8Bytes parses a sequence of bytes prefixed with an 8-bit length,
// reporting the total number of bytes consumed.
// It returns a negative length upon an error.
-func consumeUint8Bytes(b []byte) ([]byte, int) {
+func ConsumeUint8Bytes(b []byte) ([]byte, int) {
if len(b) < 1 {
return nil, -1
}
@@ -118,8 +120,8 @@ func consumeUint8Bytes(b []byte) ([]byte, int) {
return b[n:][:size], size + n
}
-// appendUint8Bytes appends a sequence of bytes prefixed by an 8-bit length.
-func appendUint8Bytes(b, v []byte) []byte {
+// AppendUint8Bytes appends a sequence of bytes prefixed by an 8-bit length.
+func AppendUint8Bytes(b, v []byte) []byte {
if len(v) > 0xff {
panic("uint8-prefixed bytes too large")
}
@@ -128,11 +130,11 @@ func appendUint8Bytes(b, v []byte) []byte {
return b
}
-// consumeVarintBytes parses a sequence of bytes preceded by a variable-length integer length,
+// ConsumeVarintBytes parses a sequence of bytes preceded by a variable-length integer length,
// reporting the total number of bytes consumed.
// It returns a negative length upon an error.
-func consumeVarintBytes(b []byte) ([]byte, int) {
- size, n := consumeVarint(b)
+func ConsumeVarintBytes(b []byte) ([]byte, int) {
+ size, n := ConsumeVarint(b)
if n < 0 {
return nil, -1
}
@@ -142,9 +144,9 @@ func consumeVarintBytes(b []byte) ([]byte, int) {
return b[n:][:size], int(size) + n
}
-// appendVarintBytes appends a sequence of bytes prefixed by a variable-length integer length.
-func appendVarintBytes(b, v []byte) []byte {
- b = appendVarint(b, uint64(len(v)))
+// AppendVarintBytes appends a sequence of bytes prefixed by a variable-length integer length.
+func AppendVarintBytes(b, v []byte) []byte {
+ b = AppendVarint(b, uint64(len(v)))
b = append(b, v...)
return b
}
diff --git a/quic/wire_test.go b/internal/quic/quicwire/wire_test.go
similarity index 73%
rename from quic/wire_test.go
rename to internal/quic/quicwire/wire_test.go
index 379da0d349..9167a5b72f 100644
--- a/quic/wire_test.go
+++ b/internal/quic/quicwire/wire_test.go
@@ -4,7 +4,7 @@
//go:build go1.21
-package quic
+package quicwire
import (
"bytes"
@@ -32,22 +32,22 @@ func TestConsumeVarint(t *testing.T) {
{[]byte{0x25}, 37, 1},
{[]byte{0x40, 0x25}, 37, 2},
} {
- got, gotLen := consumeVarint(test.b)
+ got, gotLen := ConsumeVarint(test.b)
if got != test.want || gotLen != test.wantLen {
- t.Errorf("consumeVarint(%x) = %v, %v; want %v, %v", test.b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarint(%x) = %v, %v; want %v, %v", test.b, got, gotLen, test.want, test.wantLen)
}
// Extra data in the buffer is ignored.
b := append(test.b, 0)
- got, gotLen = consumeVarint(b)
+ got, gotLen = ConsumeVarint(b)
if got != test.want || gotLen != test.wantLen {
- t.Errorf("consumeVarint(%x) = %v, %v; want %v, %v", b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarint(%x) = %v, %v; want %v, %v", b, got, gotLen, test.want, test.wantLen)
}
// Short buffer results in an error.
for i := 1; i <= len(test.b); i++ {
b = test.b[:len(test.b)-i]
- got, gotLen = consumeVarint(b)
+ got, gotLen = ConsumeVarint(b)
if got != 0 || gotLen >= 0 {
- t.Errorf("consumeVarint(%x) = %v, %v; want 0, -1", b, got, gotLen)
+ t.Errorf("ConsumeVarint(%x) = %v, %v; want 0, -1", b, got, gotLen)
}
}
}
@@ -69,11 +69,11 @@ func TestAppendVarint(t *testing.T) {
{15293, []byte{0x7b, 0xbd}},
{37, []byte{0x25}},
} {
- got := appendVarint([]byte{}, test.v)
+ got := AppendVarint([]byte{}, test.v)
if !bytes.Equal(got, test.want) {
t.Errorf("AppendVarint(nil, %v) = %x, want %x", test.v, got, test.want)
}
- if gotLen, wantLen := sizeVarint(test.v), len(got); gotLen != wantLen {
+ if gotLen, wantLen := SizeVarint(test.v), len(got); gotLen != wantLen {
t.Errorf("SizeVarint(%v) = %v, want %v", test.v, gotLen, wantLen)
}
}
@@ -88,8 +88,8 @@ func TestConsumeUint32(t *testing.T) {
{[]byte{0x01, 0x02, 0x03, 0x04}, 0x01020304, 4},
{[]byte{0x01, 0x02, 0x03}, 0, -1},
} {
- if got, n := consumeUint32(test.b); got != test.want || n != test.wantLen {
- t.Errorf("consumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
+ if got, n := ConsumeUint32(test.b); got != test.want || n != test.wantLen {
+ t.Errorf("ConsumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
}
}
}
@@ -103,8 +103,8 @@ func TestConsumeUint64(t *testing.T) {
{[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, 0x0102030405060708, 8},
{[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}, 0, -1},
} {
- if got, n := consumeUint64(test.b); got != test.want || n != test.wantLen {
- t.Errorf("consumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
+ if got, n := ConsumeUint64(test.b); got != test.want || n != test.wantLen {
+ t.Errorf("ConsumeUint32(%x) = %v, %v; want %v, %v", test.b, got, n, test.want, test.wantLen)
}
}
}
@@ -120,22 +120,22 @@ func TestConsumeVarintBytes(t *testing.T) {
{[]byte{0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 5},
{[]byte{0x40, 0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 6},
} {
- got, gotLen := consumeVarintBytes(test.b)
+ got, gotLen := ConsumeVarintBytes(test.b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
}
// Extra data in the buffer is ignored.
b := append(test.b, 0)
- got, gotLen = consumeVarintBytes(b)
+ got, gotLen = ConsumeVarintBytes(b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
}
// Short buffer results in an error.
for i := 1; i <= len(test.b); i++ {
b = test.b[:len(test.b)-i]
- got, gotLen := consumeVarintBytes(b)
+ got, gotLen := ConsumeVarintBytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
@@ -147,9 +147,9 @@ func TestConsumeVarintBytesErrors(t *testing.T) {
{0x01},
{0x40, 0x01},
} {
- got, gotLen := consumeVarintBytes(b)
+ got, gotLen := ConsumeVarintBytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeVarintBytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
}
@@ -164,22 +164,22 @@ func TestConsumeUint8Bytes(t *testing.T) {
{[]byte{0x01, 0x00}, []byte{0x00}, 2},
{[]byte{0x04, 0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, 5},
} {
- got, gotLen := consumeUint8Bytes(test.b)
+ got, gotLen := ConsumeUint8Bytes(test.b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", test.b, got, gotLen, test.want, test.wantLen)
}
// Extra data in the buffer is ignored.
b := append(test.b, 0)
- got, gotLen = consumeUint8Bytes(b)
+ got, gotLen = ConsumeUint8Bytes(b)
if !bytes.Equal(got, test.want) || gotLen != test.wantLen {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {%x}, %v", b, got, gotLen, test.want, test.wantLen)
}
// Short buffer results in an error.
for i := 1; i <= len(test.b); i++ {
b = test.b[:len(test.b)-i]
- got, gotLen := consumeUint8Bytes(b)
+ got, gotLen := ConsumeUint8Bytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
@@ -191,35 +191,35 @@ func TestConsumeUint8BytesErrors(t *testing.T) {
{0x01},
{0x04, 0x01, 0x02, 0x03},
} {
- got, gotLen := consumeUint8Bytes(b)
+ got, gotLen := ConsumeUint8Bytes(b)
if len(got) > 0 || gotLen > 0 {
- t.Errorf("consumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
+ t.Errorf("ConsumeUint8Bytes(%x) = {%x}, %v; want {}, -1", b, got, gotLen)
}
}
}
func TestAppendUint8Bytes(t *testing.T) {
var got []byte
- got = appendUint8Bytes(got, []byte{})
- got = appendUint8Bytes(got, []byte{0xaa, 0xbb})
+ got = AppendUint8Bytes(got, []byte{})
+ got = AppendUint8Bytes(got, []byte{0xaa, 0xbb})
want := []byte{
0x00,
0x02, 0xaa, 0xbb,
}
if !bytes.Equal(got, want) {
- t.Errorf("appendUint8Bytes {}, {aabb} = {%x}; want {%x}", got, want)
+ t.Errorf("AppendUint8Bytes {}, {aabb} = {%x}; want {%x}", got, want)
}
}
func TestAppendVarintBytes(t *testing.T) {
var got []byte
- got = appendVarintBytes(got, []byte{})
- got = appendVarintBytes(got, []byte{0xaa, 0xbb})
+ got = AppendVarintBytes(got, []byte{})
+ got = AppendVarintBytes(got, []byte{0xaa, 0xbb})
want := []byte{
0x00,
0x02, 0xaa, 0xbb,
}
if !bytes.Equal(got, want) {
- t.Errorf("appendVarintBytes {}, {aabb} = {%x}; want {%x}", got, want)
+ t.Errorf("AppendVarintBytes {}, {aabb} = {%x}; want {%x}", got, want)
}
}
diff --git a/internal/socket/socket_test.go b/internal/socket/socket_test.go
index 44c196b014..26077a7a5b 100644
--- a/internal/socket/socket_test.go
+++ b/internal/socket/socket_test.go
@@ -445,11 +445,7 @@ func main() {
if runtime.Compiler == "gccgo" {
t.Skip("skipping race test when built with gccgo")
}
- dir, err := os.MkdirTemp("", "testrace")
- if err != nil {
- t.Fatalf("failed to create temp directory: %v", err)
- }
- defer os.RemoveAll(dir)
+ dir := t.TempDir()
goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
t.Logf("%s version", goBinary)
got, err := exec.Command(goBinary, "version").CombinedOutput()
diff --git a/internal/socket/zsys_openbsd_ppc64.go b/internal/socket/zsys_openbsd_ppc64.go
index cebde7634f..3c9576e2d8 100644
--- a/internal/socket/zsys_openbsd_ppc64.go
+++ b/internal/socket/zsys_openbsd_ppc64.go
@@ -4,27 +4,27 @@
package socket
type iovec struct {
- Base *byte
- Len uint64
+ Base *byte
+ Len uint64
}
type msghdr struct {
- Name *byte
- Namelen uint32
- Iov *iovec
- Iovlen uint32
- Control *byte
- Controllen uint32
- Flags int32
+ Name *byte
+ Namelen uint32
+ Iov *iovec
+ Iovlen uint32
+ Control *byte
+ Controllen uint32
+ Flags int32
}
type cmsghdr struct {
- Len uint32
- Level int32
- Type int32
+ Len uint32
+ Level int32
+ Type int32
}
const (
- sizeofIovec = 0x10
- sizeofMsghdr = 0x30
+ sizeofIovec = 0x10
+ sizeofMsghdr = 0x30
)
diff --git a/internal/socket/zsys_openbsd_riscv64.go b/internal/socket/zsys_openbsd_riscv64.go
index cebde7634f..3c9576e2d8 100644
--- a/internal/socket/zsys_openbsd_riscv64.go
+++ b/internal/socket/zsys_openbsd_riscv64.go
@@ -4,27 +4,27 @@
package socket
type iovec struct {
- Base *byte
- Len uint64
+ Base *byte
+ Len uint64
}
type msghdr struct {
- Name *byte
- Namelen uint32
- Iov *iovec
- Iovlen uint32
- Control *byte
- Controllen uint32
- Flags int32
+ Name *byte
+ Namelen uint32
+ Iov *iovec
+ Iovlen uint32
+ Control *byte
+ Controllen uint32
+ Flags int32
}
type cmsghdr struct {
- Len uint32
- Level int32
- Type int32
+ Len uint32
+ Level int32
+ Type int32
}
const (
- sizeofIovec = 0x10
- sizeofMsghdr = 0x30
+ sizeofIovec = 0x10
+ sizeofMsghdr = 0x30
)
diff --git a/internal/testcert/testcert.go b/internal/testcert/testcert.go
new file mode 100644
index 0000000000..4d8ae33bba
--- /dev/null
+++ b/internal/testcert/testcert.go
@@ -0,0 +1,36 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package testcert contains a test-only localhost certificate.
+package testcert
+
+import (
+ "strings"
+)
+
+// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
+// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
+// generated from src/crypto/tls:
+// go run generate_cert.go --ecdsa-curve P256 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var LocalhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBrDCCAVKgAwIBAgIPCvPhO+Hfv+NW76kWxULUMAoGCCqGSM49BAMCMBIxEDAO
+BgNVBAoTB0FjbWUgQ28wIBcNNzAwMTAxMDAwMDAwWhgPMjA4NDAxMjkxNjAwMDBa
+MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARh
+WRF8p8X9scgW7JjqAwI9nYV8jtkdhqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGms
+PyfMPe5Jrha/LmjgR1G9o4GIMIGFMA4GA1UdDwEB/wQEAwIChDATBgNVHSUEDDAK
+BggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSOJri/wLQxq6oC
+Y6ZImms/STbTljAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAA
+AAAAAAAAAAAAATAKBggqhkjOPQQDAgNIADBFAiBUguxsW6TGhixBAdORmVNnkx40
+HjkKwncMSDbUaeL9jQIhAJwQ8zV9JpQvYpsiDuMmqCuW35XXil3cQ6Drz82c+fvE
+-----END CERTIFICATE-----`)
+
+// LocalhostKey is the private key for localhostCert.
+var LocalhostKey = []byte(testingKey(`-----BEGIN TESTING KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgY1B1eL/Bbwf/MDcs
+rnvvWhFNr1aGmJJR59PdCN9lVVqhRANCAARhWRF8p8X9scgW7JjqAwI9nYV8jtkd
+hqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGmsPyfMPe5Jrha/LmjgR1G9
+-----END TESTING KEY-----`))
+
+// testingKey helps keep security scanners from getting excited about a private key in this file.
+func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
diff --git a/proxy/per_host.go b/proxy/per_host.go
index d7d4b8b6e3..32bdf435ec 100644
--- a/proxy/per_host.go
+++ b/proxy/per_host.go
@@ -7,6 +7,7 @@ package proxy
import (
"context"
"net"
+ "net/netip"
"strings"
)
@@ -57,7 +58,8 @@ func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.
}
func (p *PerHost) dialerForRequest(host string) Dialer {
- if ip := net.ParseIP(host); ip != nil {
+ if nip, err := netip.ParseAddr(host); err == nil {
+ ip := net.IP(nip.AsSlice())
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
@@ -108,8 +110,8 @@ func (p *PerHost) AddFromString(s string) {
}
continue
}
- if ip := net.ParseIP(host); ip != nil {
- p.AddIP(ip)
+ if nip, err := netip.ParseAddr(host); err == nil {
+ p.AddIP(net.IP(nip.AsSlice()))
continue
}
if strings.HasPrefix(host, "*.") {
diff --git a/proxy/per_host_test.go b/proxy/per_host_test.go
index 0447eb427a..b7bcec8ae3 100644
--- a/proxy/per_host_test.go
+++ b/proxy/per_host_test.go
@@ -7,8 +7,9 @@ package proxy
import (
"context"
"errors"
+ "fmt"
"net"
- "reflect"
+ "slices"
"testing"
)
@@ -22,55 +23,118 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
}
func TestPerHost(t *testing.T) {
- expectedDef := []string{
- "example.com:123",
- "1.2.3.4:123",
- "[1001::]:123",
- }
- expectedBypass := []string{
- "localhost:123",
- "zone:123",
- "foo.zone:123",
- "127.0.0.1:123",
- "10.1.2.3:123",
- "[1000::]:123",
- }
-
- t.Run("Dial", func(t *testing.T) {
- var def, bypass recordingProxy
- perHost := NewPerHost(&def, &bypass)
- perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
- for _, addr := range expectedDef {
- perHost.Dial("tcp", addr)
+ for _, test := range []struct {
+ config string // passed to PerHost.AddFromString
+ nomatch []string // addrs using the default dialer
+ match []string // addrs using the bypass dialer
+ }{{
+ config: "localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16",
+ nomatch: []string{
+ "example.com:123",
+ "1.2.3.4:123",
+ "[1001::]:123",
+ },
+ match: []string{
+ "localhost:123",
+ "zone:123",
+ "foo.zone:123",
+ "127.0.0.1:123",
+ "10.1.2.3:123",
+ "[1000::]:123",
+ "[1000::%25.example.com]:123",
+ },
+ }, {
+ config: "localhost",
+ nomatch: []string{
+ "127.0.0.1:80",
+ },
+ match: []string{
+ "localhost:80",
+ },
+ }, {
+ config: "*.zone",
+ nomatch: []string{
+ "foo.com:80",
+ },
+ match: []string{
+ "foo.zone:80",
+ "foo.bar.zone:80",
+ },
+ }, {
+ config: "1.2.3.4",
+ nomatch: []string{
+ "127.0.0.1:80",
+ "11.2.3.4:80",
+ },
+ match: []string{
+ "1.2.3.4:80",
+ },
+ }, {
+ config: "10.0.0.0/24",
+ nomatch: []string{
+ "10.0.1.1:80",
+ },
+ match: []string{
+ "10.0.0.1:80",
+ "10.0.0.255:80",
+ },
+ }, {
+ config: "fe80::/10",
+ nomatch: []string{
+ "[fec0::1]:80",
+ "[fec0::1%en0]:80",
+ },
+ match: []string{
+ "[fe80::1]:80",
+ "[fe80::1%en0]:80",
+ },
+ }, {
+ // We don't allow zone IDs in network prefixes,
+ // so this config matches nothing.
+ config: "fe80::%en0/10",
+ nomatch: []string{
+ "[fec0::1]:80",
+ "[fec0::1%en0]:80",
+ "[fe80::1]:80",
+ "[fe80::1%en0]:80",
+ "[fe80::1%en1]:80",
+ },
+ }} {
+ for _, addr := range test.match {
+ testPerHost(t, test.config, addr, true)
}
- for _, addr := range expectedBypass {
- perHost.Dial("tcp", addr)
+ for _, addr := range test.nomatch {
+ testPerHost(t, test.config, addr, false)
}
+ }
+}
- if !reflect.DeepEqual(expectedDef, def.addrs) {
- t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
- }
- if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
- t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
- }
- })
+func testPerHost(t *testing.T, config, addr string, wantMatch bool) {
+ name := fmt.Sprintf("config %q, dial %q", config, addr)
- t.Run("DialContext", func(t *testing.T) {
- var def, bypass recordingProxy
- perHost := NewPerHost(&def, &bypass)
- perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
- for _, addr := range expectedDef {
- perHost.DialContext(context.Background(), "tcp", addr)
- }
- for _, addr := range expectedBypass {
- perHost.DialContext(context.Background(), "tcp", addr)
- }
+ var def, bypass recordingProxy
+ perHost := NewPerHost(&def, &bypass)
+ perHost.AddFromString(config)
+ perHost.Dial("tcp", addr)
- if !reflect.DeepEqual(expectedDef, def.addrs) {
- t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
- }
- if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
- t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
- }
- })
+ // Dial and DialContext should have the same results.
+ var defc, bypassc recordingProxy
+ perHostc := NewPerHost(&defc, &bypassc)
+ perHostc.AddFromString(config)
+ perHostc.DialContext(context.Background(), "tcp", addr)
+ if !slices.Equal(def.addrs, defc.addrs) {
+ t.Errorf("%v: Dial default=%v, bypass=%v; DialContext default=%v, bypass=%v", name, def.addrs, bypass.addrs, defc.addrs, bypass.addrs)
+ return
+ }
+
+ if got, want := slices.Concat(def.addrs, bypass.addrs), []string{addr}; !slices.Equal(got, want) {
+ t.Errorf("%v: dialed %q, want %q", name, got, want)
+ return
+ }
+
+ gotMatch := len(bypass.addrs) > 0
+ if gotMatch != wantMatch {
+ t.Errorf("%v: matched=%v, want %v", name, gotMatch, wantMatch)
+ return
+ }
}
diff --git a/publicsuffix/gen.go b/publicsuffix/gen.go
index 7f7d08dbc2..5f454e57e9 100644
--- a/publicsuffix/gen.go
+++ b/publicsuffix/gen.go
@@ -21,6 +21,7 @@ package main
import (
"bufio"
"bytes"
+ "cmp"
"encoding/binary"
"flag"
"fmt"
@@ -29,7 +30,7 @@ import (
"net/http"
"os"
"regexp"
- "sort"
+ "slices"
"strings"
"golang.org/x/net/idna"
@@ -62,20 +63,6 @@ var (
maxLo uint32
)
-func max(a, b int) int {
- if a < b {
- return b
- }
- return a
-}
-
-func u32max(a, b uint32) uint32 {
- if a < b {
- return b
- }
- return a
-}
-
const (
nodeTypeNormal = 0
nodeTypeException = 1
@@ -83,18 +70,6 @@ const (
numNodeType = 3
)
-func nodeTypeStr(n int) string {
- switch n {
- case nodeTypeNormal:
- return "+"
- case nodeTypeException:
- return "!"
- case nodeTypeParentOnly:
- return "o"
- }
- panic("unreachable")
-}
-
const (
defaultURL = "https://publicsuffix.org/list/effective_tld_names.dat"
gitCommitURL = "https://api.github.com/repos/publicsuffix/list/commits?path=public_suffix_list.dat"
@@ -251,7 +226,7 @@ func main1() error {
for label := range labelsMap {
labelsList = append(labelsList, label)
}
- sort.Strings(labelsList)
+ slices.Sort(labelsList)
combinedText = combineText(labelsList)
if combinedText == "" {
@@ -509,15 +484,13 @@ func (n *node) child(label string) *node {
icann: true,
}
n.children = append(n.children, c)
- sort.Sort(byLabel(n.children))
+ slices.SortFunc(n.children, byLabel)
return c
}
-type byLabel []*node
-
-func (b byLabel) Len() int { return len(b) }
-func (b byLabel) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
-func (b byLabel) Less(i, j int) bool { return b[i].label < b[j].label }
+func byLabel(a, b *node) int {
+ return strings.Compare(a.label, b.label)
+}
var nextNodesIndex int
@@ -557,7 +530,7 @@ func assignIndexes(n *node) error {
n.childrenIndex = len(childrenEncoding)
lo := uint32(n.firstChild)
hi := lo + uint32(len(n.children))
- maxLo, maxHi = u32max(maxLo, lo), u32max(maxHi, hi)
+ maxLo, maxHi = max(maxLo, lo), max(maxHi, hi)
if lo >= 1< 0 && ss[0] == "" {
ss = ss[1:]
}
diff --git a/publicsuffix/list.go b/publicsuffix/list.go
index d56e9e7624..56069d0429 100644
--- a/publicsuffix/list.go
+++ b/publicsuffix/list.go
@@ -88,7 +88,7 @@ func PublicSuffix(domain string) (publicSuffix string, icann bool) {
s, suffix, icannNode, wildcard := domain, len(domain), false, false
loop:
for {
- dot := strings.LastIndex(s, ".")
+ dot := strings.LastIndexByte(s, '.')
if wildcard {
icann = icannNode
suffix = 1 + dot
@@ -129,7 +129,7 @@ loop:
}
if suffix == len(domain) {
// If no rules match, the prevailing rule is "*".
- return domain[1+strings.LastIndex(domain, "."):], icann
+ return domain[1+strings.LastIndexByte(domain, '.'):], icann
}
return domain[suffix:], icann
}
@@ -178,26 +178,28 @@ func EffectiveTLDPlusOne(domain string) (string, error) {
if domain[i] != '.' {
return "", fmt.Errorf("publicsuffix: invalid public suffix %q for domain %q", suffix, domain)
}
- return domain[1+strings.LastIndex(domain[:i], "."):], nil
+ return domain[1+strings.LastIndexByte(domain[:i], '.'):], nil
}
type uint32String string
func (u uint32String) get(i uint32) uint32 {
off := i * 4
- return (uint32(u[off])<<24 |
- uint32(u[off+1])<<16 |
- uint32(u[off+2])<<8 |
- uint32(u[off+3]))
+ u = u[off:] // help the compiler reduce bounds checks
+ return uint32(u[3]) |
+ uint32(u[2])<<8 |
+ uint32(u[1])<<16 |
+ uint32(u[0])<<24
}
type uint40String string
func (u uint40String) get(i uint32) uint64 {
off := uint64(i * (nodesBits / 8))
- return uint64(u[off])<<32 |
- uint64(u[off+1])<<24 |
- uint64(u[off+2])<<16 |
- uint64(u[off+3])<<8 |
- uint64(u[off+4])
+ u = u[off:] // help the compiler reduce bounds checks
+ return uint64(u[4]) |
+ uint64(u[3])<<8 |
+ uint64(u[2])<<16 |
+ uint64(u[1])<<24 |
+ uint64(u[0])<<32
}
diff --git a/quic/config.go b/quic/config.go
index 5d420312bb..d6aa87730f 100644
--- a/quic/config.go
+++ b/quic/config.go
@@ -11,6 +11,8 @@ import (
"log/slog"
"math"
"time"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// A Config structure configures a QUIC endpoint.
@@ -134,15 +136,15 @@ func (c *Config) maxUniRemoteStreams() int64 {
}
func (c *Config) maxStreamReadBufferSize() int64 {
- return configDefault(c.MaxStreamReadBufferSize, 1<<20, maxVarint)
+ return configDefault(c.MaxStreamReadBufferSize, 1<<20, quicwire.MaxVarint)
}
func (c *Config) maxStreamWriteBufferSize() int64 {
- return configDefault(c.MaxStreamWriteBufferSize, 1<<20, maxVarint)
+ return configDefault(c.MaxStreamWriteBufferSize, 1<<20, quicwire.MaxVarint)
}
func (c *Config) maxConnReadBufferSize() int64 {
- return configDefault(c.MaxConnReadBufferSize, 1<<20, maxVarint)
+ return configDefault(c.MaxConnReadBufferSize, 1<<20, quicwire.MaxVarint)
}
func (c *Config) handshakeTimeout() time.Duration {
diff --git a/quic/conn.go b/quic/conn.go
index 38e8fe8f4e..1f1cfa6d0a 100644
--- a/quic/conn.go
+++ b/quic/conn.go
@@ -176,6 +176,21 @@ func (c *Conn) String() string {
return fmt.Sprintf("quic.Conn(%v,->%v)", c.side, c.peerAddr)
}
+// LocalAddr returns the local network address, if known.
+func (c *Conn) LocalAddr() netip.AddrPort {
+ return c.localAddr
+}
+
+// RemoteAddr returns the remote network address, if known.
+func (c *Conn) RemoteAddr() netip.AddrPort {
+ return c.peerAddr
+}
+
+// ConnectionState returns basic TLS details about the connection.
+func (c *Conn) ConnectionState() tls.ConnectionState {
+ return c.tls.ConnectionState()
+}
+
// confirmHandshake is called when the handshake is confirmed.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2
func (c *Conn) confirmHandshake(now time.Time) {
@@ -206,6 +221,9 @@ func (c *Conn) confirmHandshake(now time.Time) {
// discardKeys discards unused packet protection keys.
// https://www.rfc-editor.org/rfc/rfc9001#section-4.9
func (c *Conn) discardKeys(now time.Time, space numberSpace) {
+ if err := c.crypto[space].discardKeys(); err != nil {
+ c.abort(now, err)
+ }
switch space {
case initialSpace:
c.keysInitial.discard()
diff --git a/quic/conn_close.go b/quic/conn_close.go
index 1798d0536f..cd8d7e3c5a 100644
--- a/quic/conn_close.go
+++ b/quic/conn_close.go
@@ -178,7 +178,7 @@ func (c *Conn) sendOK(now time.Time) bool {
}
}
-// sendConnectionClose reports that the conn has sent a CONNECTION_CLOSE to the peer.
+// sentConnectionClose reports that the conn has sent a CONNECTION_CLOSE to the peer.
func (c *Conn) sentConnectionClose(now time.Time) {
switch c.lifetime.state {
case connStatePeerClosed:
@@ -230,6 +230,17 @@ func (c *Conn) setFinalError(err error) {
close(c.lifetime.donec)
}
+// finalError returns the final connection status reported to the user,
+// or nil if a final status has not yet been set.
+func (c *Conn) finalError() error {
+ select {
+ case <-c.lifetime.donec:
+ return c.lifetime.finalErr
+ default:
+ }
+ return nil
+}
+
func (c *Conn) waitReady(ctx context.Context) error {
select {
case <-c.lifetime.readyc:
diff --git a/quic/conn_id.go b/quic/conn_id.go
index 2efe8d6b5d..2d50f14fa6 100644
--- a/quic/conn_id.go
+++ b/quic/conn_id.go
@@ -9,6 +9,7 @@ package quic
import (
"bytes"
"crypto/rand"
+ "slices"
)
// connIDState is a conn's connection IDs.
@@ -25,8 +26,16 @@ type connIDState struct {
remote []remoteConnID
nextLocalSeq int64
- retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer
- peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter
+ peerActiveConnIDLimit int64 // peer's active_connection_id_limit
+
+ // Handling of retirement of remote connection IDs.
+ // The rangesets track ID sequence numbers.
+ // IDs in need of retirement are added to remoteRetiring,
+ // moved to remoteRetiringSent once we send a RETIRE_CONNECTION_ID frame,
+ // and removed from the set once retirement completes.
+ retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer
+ remoteRetiring rangeset[int64] // remote IDs in need of retirement
+ remoteRetiringSent rangeset[int64] // remote IDs waiting for ack of retirement
originalDstConnID []byte // expected original_destination_connection_id param
retrySrcConnID []byte // expected retry_source_connection_id param
@@ -45,9 +54,6 @@ type connID struct {
// For the transient destination ID in a client's Initial packet, this is -1.
seq int64
- // retired is set when the connection ID is retired.
- retired bool
-
// send is set when the connection ID's state needs to be sent to the peer.
//
// For local IDs, this indicates a new ID that should be sent
@@ -144,9 +150,7 @@ func (s *connIDState) srcConnID() []byte {
// dstConnID is the Destination Connection ID to use in a sent packet.
func (s *connIDState) dstConnID() (cid []byte, ok bool) {
for i := range s.remote {
- if !s.remote[i].retired {
- return s.remote[i].cid, true
- }
+ return s.remote[i].cid, true
}
return nil, false
}
@@ -154,14 +158,12 @@ func (s *connIDState) dstConnID() (cid []byte, ok bool) {
// isValidStatelessResetToken reports whether the given reset token is
// associated with a non-retired connection ID which we have used.
func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool {
- for i := range s.remote {
- // We currently only use the first available remote connection ID,
- // so any other reset token is not valid.
- if !s.remote[i].retired {
- return s.remote[i].resetToken == resetToken
- }
+ if len(s.remote) == 0 {
+ return false
}
- return false
+ // We currently only use the first available remote connection ID,
+ // so any other reset token is not valid.
+ return s.remote[0].resetToken == resetToken
}
// setPeerActiveConnIDLimit sets the active_connection_id_limit
@@ -174,7 +176,7 @@ func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error {
func (s *connIDState) issueLocalIDs(c *Conn) error {
toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit)
for i := range s.local {
- if s.local[i].seq != -1 && !s.local[i].retired {
+ if s.local[i].seq != -1 {
toIssue--
}
}
@@ -271,7 +273,7 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte)
}
}
case ptype == packetTypeHandshake && c.side == serverSide:
- if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired {
+ if len(s.local) > 0 && s.local[0].seq == -1 {
// We're a server connection processing the first Handshake packet from
// the client. Discard the transient, client-chosen connection ID used
// for Initial packets; the client will never send it again.
@@ -304,23 +306,29 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
}
}
+ if seq < s.retireRemotePriorTo {
+ // This ID was already retired by a previous NEW_CONNECTION_ID frame.
+ // Nothing to do.
+ return nil
+ }
+
if retire > s.retireRemotePriorTo {
+ // Add newly-retired connection IDs to the set we need to send
+ // RETIRE_CONNECTION_ID frames for, and remove them from s.remote.
+ //
+ // (This might cause us to send a RETIRE_CONNECTION_ID for an ID we've
+ // never seen. That's fine.)
+ s.remoteRetiring.add(s.retireRemotePriorTo, retire)
s.retireRemotePriorTo = retire
+ s.needSend = true
+ s.remote = slices.DeleteFunc(s.remote, func(rcid remoteConnID) bool {
+ return rcid.seq < s.retireRemotePriorTo
+ })
}
have := false // do we already have this connection ID?
- active := 0
for i := range s.remote {
rcid := &s.remote[i]
- if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo {
- s.retireRemote(rcid)
- c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
- conns.retireResetToken(c, rcid.resetToken)
- })
- }
- if !rcid.retired {
- active++
- }
if rcid.seq == seq {
if !bytes.Equal(rcid.cid, cid) {
return localTransportError{
@@ -329,6 +337,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
}
}
have = true // yes, we've seen this sequence number
+ break
}
}
@@ -345,18 +354,12 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
},
resetToken: resetToken,
})
- if seq < s.retireRemotePriorTo {
- // This ID was already retired by a previous NEW_CONNECTION_ID frame.
- s.retireRemote(&s.remote[len(s.remote)-1])
- } else {
- active++
- c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
- conns.addResetToken(c, resetToken)
- })
- }
+ c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) {
+ conns.addResetToken(c, resetToken)
+ })
}
- if active > activeConnIDLimit {
+ if len(s.remote) > activeConnIDLimit {
// Retired connection IDs (including newly-retired ones) do not count
// against the limit.
// https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5
@@ -370,25 +373,18 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re
// for which RETIRE_CONNECTION_ID frames have not yet been acknowledged."
// https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6
//
- // Set a limit of four times the active_connection_id_limit for
- // the total number of remote connection IDs we keep state for locally.
- if len(s.remote) > 4*activeConnIDLimit {
+ // Set a limit of three times the active_connection_id_limit for
+ // the total number of remote connection IDs we keep retirement state for.
+ if s.remoteRetiring.size()+s.remoteRetiringSent.size() > 3*activeConnIDLimit {
return localTransportError{
code: errConnectionIDLimit,
- reason: "too many unacknowledged RETIRE_CONNECTION_ID frames",
+ reason: "too many unacknowledged retired connection ids",
}
}
return nil
}
-// retireRemote marks a remote connection ID as retired.
-func (s *connIDState) retireRemote(rcid *remoteConnID) {
- rcid.retired = true
- rcid.send.setUnsent()
- s.needSend = true
-}
-
func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error {
if seq >= s.nextLocalSeq {
return localTransportError{
@@ -424,20 +420,11 @@ func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fat
}
func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) {
- for i := 0; i < len(s.remote); i++ {
- if s.remote[i].seq != seq {
- continue
- }
- if fate == packetAcked {
- // We have retired this connection ID, and the peer has acked.
- // Discard its state completely.
- s.remote = append(s.remote[:i], s.remote[i+1:]...)
- } else {
- // RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
- s.needSend = true
- s.remote[i].send.ackOrLoss(pnum, fate)
- }
- return
+ s.remoteRetiringSent.sub(seq, seq+1)
+ if fate == packetLost {
+ // RETIRE_CONNECTION_ID frame was lost, mark for retransmission.
+ s.remoteRetiring.add(seq, seq+1)
+ s.needSend = true
}
}
@@ -469,14 +456,22 @@ func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool {
}
s.local[i].send.setSent(pnum)
}
- for i := range s.remote {
- if !s.remote[i].send.shouldSendPTO(pto) {
- continue
+ if pto {
+ for _, r := range s.remoteRetiringSent {
+ for cid := r.start; cid < r.end; cid++ {
+ if !c.w.appendRetireConnectionIDFrame(cid) {
+ return false
+ }
+ }
}
- if !c.w.appendRetireConnectionIDFrame(s.remote[i].seq) {
+ }
+ for s.remoteRetiring.numRanges() > 0 {
+ cid := s.remoteRetiring.min()
+ if !c.w.appendRetireConnectionIDFrame(cid) {
return false
}
- s.remote[i].send.setSent(pnum)
+ s.remoteRetiring.sub(cid, cid+1)
+ s.remoteRetiringSent.add(cid, cid+1)
}
s.needSend = false
return true
diff --git a/quic/conn_id_test.go b/quic/conn_id_test.go
index d44472e813..2c3f170160 100644
--- a/quic/conn_id_test.go
+++ b/quic/conn_id_test.go
@@ -664,3 +664,52 @@ func TestConnIDsCleanedUpAfterClose(t *testing.T) {
}
})
}
+
+func TestConnIDRetiredConnIDResent(t *testing.T) {
+ tc := newTestConn(t, serverSide)
+ tc.handshake()
+ tc.ignoreFrame(frameTypeAck)
+ //tc.ignoreFrame(frameTypeRetireConnectionID)
+
+ // Send CID 2, retire 0-1 (negotiated during the handshake).
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 2,
+ retirePriorTo: 2,
+ connID: testPeerConnID(2),
+ token: testPeerStatelessResetToken(2),
+ })
+ tc.wantFrame("retire CID 0", packetType1RTT, debugFrameRetireConnectionID{seq: 0})
+ tc.wantFrame("retire CID 1", packetType1RTT, debugFrameRetireConnectionID{seq: 1})
+
+ // Send CID 3, retire 2.
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 3,
+ retirePriorTo: 3,
+ connID: testPeerConnID(3),
+ token: testPeerStatelessResetToken(3),
+ })
+ tc.wantFrame("retire CID 2", packetType1RTT, debugFrameRetireConnectionID{seq: 2})
+
+ // Acknowledge retirement of CIDs 0-2.
+ // The server should have state for only one CID: 3.
+ tc.writeAckForAll()
+ if got, want := len(tc.conn.connIDState.remote), 1; got != want {
+ t.Fatalf("connection has state for %v connection IDs, want %v", got, want)
+ }
+
+ // Send CID 2 again.
+ // The server should ignore this, since it's already retired the CID.
+ tc.ignoreFrames[frameTypeRetireConnectionID] = false
+ tc.writeFrames(packetType1RTT,
+ debugFrameNewConnectionID{
+ seq: 2,
+ connID: testPeerConnID(2),
+ token: testPeerStatelessResetToken(2),
+ })
+ if got, want := len(tc.conn.connIDState.remote), 1; got != want {
+ t.Fatalf("connection has state for %v connection IDs, want %v", got, want)
+ }
+ tc.wantIdle("server does not re-retire already retired CID 2")
+}
diff --git a/quic/conn_recv.go b/quic/conn_recv.go
index b1354cd3a1..dbfe34a343 100644
--- a/quic/conn_recv.go
+++ b/quic/conn_recv.go
@@ -285,6 +285,7 @@ func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, sp
__01 = packetType0RTT | packetType1RTT
___1 = packetType1RTT
)
+ hasCrypto := false
for len(payload) > 0 {
switch payload[0] {
case frameTypePadding, frameTypeAck, frameTypeAckECN,
@@ -322,6 +323,7 @@ func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, sp
if !frameOK(c, ptype, IH_1) {
return
}
+ hasCrypto = true
n = c.handleCryptoFrame(now, space, payload)
case frameTypeNewToken:
if !frameOK(c, ptype, ___1) {
@@ -406,6 +408,15 @@ func (c *Conn) handleFrames(now time.Time, dgram *datagram, ptype packetType, sp
}
payload = payload[n:]
}
+ if hasCrypto {
+ // Process TLS events after handling all frames in a packet.
+ // TLS events can cause us to drop state for a number space,
+ // so do that last, to avoid handling frames differently
+ // depending on whether they come before or after a CRYPTO frame.
+ if err := c.handleTLSEvents(now); err != nil {
+ c.abort(now, err)
+ }
+ }
return ackEliciting
}
diff --git a/quic/conn_recv_test.go b/quic/conn_recv_test.go
new file mode 100644
index 0000000000..0e94731bf7
--- /dev/null
+++ b/quic/conn_recv_test.go
@@ -0,0 +1,60 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "crypto/tls"
+ "testing"
+)
+
+func TestConnReceiveAckForUnsentPacket(t *testing.T) {
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.handshake()
+ tc.writeFrames(packetType1RTT,
+ debugFrameAck{
+ ackDelay: 0,
+ ranges: []i64range[packetNumber]{{0, 10}},
+ })
+ tc.wantFrame("ACK for unsent packet causes CONNECTION_CLOSE",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errProtocolViolation,
+ })
+}
+
+// Issue #70703: If a packet contains both a CRYPTO frame which causes us to
+// drop state for a number space, and also contains a valid ACK frame for that space,
+// we shouldn't complain about the ACK.
+func TestConnReceiveAckForDroppedSpace(t *testing.T) {
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("send Initial crypto",
+ packetTypeInitial, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelInitial],
+ })
+ tc.wantFrame("send Handshake crypto",
+ packetTypeHandshake, debugFrameCrypto{
+ data: tc.cryptoDataOut[tls.QUICEncryptionLevelHandshake],
+ })
+
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ },
+ debugFrameAck{
+ ackDelay: 0,
+ ranges: []i64range[packetNumber]{{0, tc.lastPacket.num + 1}},
+ })
+ tc.wantFrame("handshake finishes",
+ packetType1RTT, debugFrameHandshakeDone{})
+ tc.wantIdle("connection is idle")
+}
diff --git a/quic/conn_test.go b/quic/conn_test.go
index f4f1818a64..51402630fc 100644
--- a/quic/conn_test.go
+++ b/quic/conn_test.go
@@ -436,7 +436,7 @@ func (tc *testConn) write(d *testDatagram) {
tc.endpoint.writeDatagram(d)
}
-// writeFrame sends the Conn a datagram containing the given frames.
+// writeFrames sends the Conn a datagram containing the given frames.
func (tc *testConn) writeFrames(ptype packetType, frames ...debugFrame) {
tc.t.Helper()
space := spaceForPacketType(ptype)
diff --git a/quic/crypto_stream.go b/quic/crypto_stream.go
index a4dcb32eb7..806c963943 100644
--- a/quic/crypto_stream.go
+++ b/quic/crypto_stream.go
@@ -139,3 +139,21 @@ func (s *cryptoStream) sendData(off int64, b []byte) {
s.out.copy(off, b)
s.outunsent.sub(off, off+int64(len(b)))
}
+
+// discardKeys is called when the packet protection keys for the stream are dropped.
+func (s *cryptoStream) discardKeys() error {
+ if s.in.end-s.in.start != 0 {
+ // The peer sent some unprocessed CRYPTO data that we're about to discard.
+ // Close the connection with a TLS unexpected_message alert.
+ // https://www.rfc-editor.org/rfc/rfc5246#section-7.2.2
+ const unexpectedMessage = 10
+ return localTransportError{
+ code: errTLSBase + unexpectedMessage,
+ reason: "excess crypto data",
+ }
+ }
+ // Discard any unacked (but presumably received) data in our output buffer.
+ s.out.discardBefore(s.out.end)
+ *s = cryptoStream{}
+ return nil
+}
diff --git a/quic/endpoint.go b/quic/endpoint.go
index a55336b240..b9ababe6b1 100644
--- a/quic/endpoint.go
+++ b/quic/endpoint.go
@@ -73,6 +73,25 @@ func Listen(network, address string, listenConfig *Config) (*Endpoint, error) {
return newEndpoint(pc, listenConfig, nil)
}
+// NewEndpoint creates an endpoint using a net.PacketConn as the underlying transport.
+//
+// If the PacketConn is not a *net.UDPConn, the endpoint may be slower and lack
+// access to some features of the network.
+func NewEndpoint(conn net.PacketConn, config *Config) (*Endpoint, error) {
+ var pc packetConn
+ var err error
+ switch conn := conn.(type) {
+ case *net.UDPConn:
+ pc, err = newNetUDPConn(conn)
+ default:
+ pc, err = newNetPacketConn(conn)
+ }
+ if err != nil {
+ return nil, err
+ }
+ return newEndpoint(pc, config, nil)
+}
+
func newEndpoint(pc packetConn, config *Config, hooks endpointTestHooks) (*Endpoint, error) {
e := &Endpoint{
listenConfig: config,
@@ -448,7 +467,7 @@ func (m *connsMap) updateConnIDs(f func(*connsMap)) {
m.updateNeeded.Store(true)
}
-// applyConnIDUpdates is called by the datagram receive loop to update its connection ID map.
+// applyUpdates is called by the datagram receive loop to update its connection ID map.
func (m *connsMap) applyUpdates() {
m.updateMu.Lock()
defer m.updateMu.Unlock()
diff --git a/quic/errors.go b/quic/errors.go
index 954793cfc0..b805b93c1b 100644
--- a/quic/errors.go
+++ b/quic/errors.go
@@ -121,8 +121,7 @@ type ApplicationError struct {
}
func (e *ApplicationError) Error() string {
- // TODO: Include the Reason string here, but sanitize it first.
- return fmt.Sprintf("AppError %v", e.Code)
+ return fmt.Sprintf("peer closed connection: %v: %q", e.Code, e.Reason)
}
// Is reports a match if err is an *ApplicationError with a matching Code.
diff --git a/quic/gate.go b/quic/gate.go
index a2fb537115..8f1db2be66 100644
--- a/quic/gate.go
+++ b/quic/gate.go
@@ -27,7 +27,7 @@ func newGate() gate {
return g
}
-// newLocked gate returns a new, locked gate.
+// newLockedGate returns a new, locked gate.
func newLockedGate() gate {
return gate{
set: make(chan struct{}, 1),
@@ -84,7 +84,7 @@ func (g *gate) unlock(set bool) {
}
}
-// unlock sets the condition to the result of f and releases the gate.
+// unlockFunc sets the condition to the result of f and releases the gate.
// Useful in defers.
func (g *gate) unlockFunc(f func() bool) {
g.unlock(f())
diff --git a/quic/packet.go b/quic/packet.go
index 7a874319d7..883754f021 100644
--- a/quic/packet.go
+++ b/quic/packet.go
@@ -9,6 +9,8 @@ package quic
import (
"encoding/binary"
"fmt"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// packetType is a QUIC packet type.
@@ -196,10 +198,10 @@ func parseVersionNegotiation(pkt []byte) (dstConnID, srcConnID, versions []byte)
// appendVersionNegotiation appends a Version Negotiation packet to pkt,
// returning the result.
func appendVersionNegotiation(pkt, dstConnID, srcConnID []byte, versions ...uint32) []byte {
- pkt = append(pkt, headerFormLong|fixedBit) // header byte
- pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation)
- pkt = appendUint8Bytes(pkt, dstConnID) // Destination Connection ID
- pkt = appendUint8Bytes(pkt, srcConnID) // Source Connection ID
+ pkt = append(pkt, headerFormLong|fixedBit) // header byte
+ pkt = append(pkt, 0, 0, 0, 0) // Version (0 for Version Negotiation)
+ pkt = quicwire.AppendUint8Bytes(pkt, dstConnID) // Destination Connection ID
+ pkt = quicwire.AppendUint8Bytes(pkt, srcConnID) // Source Connection ID
for _, v := range versions {
pkt = binary.BigEndian.AppendUint32(pkt, v) // Supported Version
}
@@ -243,21 +245,21 @@ func parseGenericLongHeaderPacket(b []byte) (p genericLongPacket, ok bool) {
b = b[1:]
// Version (32),
var n int
- p.version, n = consumeUint32(b)
+ p.version, n = quicwire.ConsumeUint32(b)
if n < 0 {
return genericLongPacket{}, false
}
b = b[n:]
// Destination Connection ID Length (8),
// Destination Connection ID (0..2048),
- p.dstConnID, n = consumeUint8Bytes(b)
+ p.dstConnID, n = quicwire.ConsumeUint8Bytes(b)
if n < 0 || len(p.dstConnID) > 2048/8 {
return genericLongPacket{}, false
}
b = b[n:]
// Source Connection ID Length (8),
// Source Connection ID (0..2048),
- p.srcConnID, n = consumeUint8Bytes(b)
+ p.srcConnID, n = quicwire.ConsumeUint8Bytes(b)
if n < 0 || len(p.dstConnID) > 2048/8 {
return genericLongPacket{}, false
}
diff --git a/quic/packet_codec_test.go b/quic/packet_codec_test.go
index 3b39795ef5..2a2b08f4e3 100644
--- a/quic/packet_codec_test.go
+++ b/quic/packet_codec_test.go
@@ -15,6 +15,7 @@ import (
"testing"
"time"
+ "golang.org/x/net/internal/quic/quicwire"
"golang.org/x/net/quic/qlog"
)
@@ -736,7 +737,7 @@ func TestFrameDecodeErrors(t *testing.T) {
name: "MAX_STREAMS with too many streams",
b: func() []byte {
// https://www.rfc-editor.org/rfc/rfc9000.html#section-19.11-5.2.1
- return appendVarint([]byte{frameTypeMaxStreamsBidi}, (1<<60)+1)
+ return quicwire.AppendVarint([]byte{frameTypeMaxStreamsBidi}, (1<<60)+1)
}(),
}, {
name: "NEW_CONNECTION_ID too small",
diff --git a/quic/packet_parser.go b/quic/packet_parser.go
index feef9eac7f..dca3018086 100644
--- a/quic/packet_parser.go
+++ b/quic/packet_parser.go
@@ -6,6 +6,8 @@
package quic
+import "golang.org/x/net/internal/quic/quicwire"
+
// parseLongHeaderPacket parses a QUIC long header packet.
//
// It does not parse Version Negotiation packets.
@@ -34,7 +36,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
}
b = b[1:]
// Version (32),
- p.version, n = consumeUint32(b)
+ p.version, n = quicwire.ConsumeUint32(b)
if n < 0 {
return longPacket{}, -1
}
@@ -46,7 +48,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
// Destination Connection ID Length (8),
// Destination Connection ID (0..160),
- p.dstConnID, n = consumeUint8Bytes(b)
+ p.dstConnID, n = quicwire.ConsumeUint8Bytes(b)
if n < 0 || len(p.dstConnID) > maxConnIDLen {
return longPacket{}, -1
}
@@ -54,7 +56,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
// Source Connection ID Length (8),
// Source Connection ID (0..160),
- p.srcConnID, n = consumeUint8Bytes(b)
+ p.srcConnID, n = quicwire.ConsumeUint8Bytes(b)
if n < 0 || len(p.dstConnID) > maxConnIDLen {
return longPacket{}, -1
}
@@ -64,7 +66,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
case packetTypeInitial:
// Token Length (i),
// Token (..),
- p.extra, n = consumeVarintBytes(b)
+ p.extra, n = quicwire.ConsumeVarintBytes(b)
if n < 0 {
return longPacket{}, -1
}
@@ -77,7 +79,7 @@ func parseLongHeaderPacket(pkt []byte, k fixedKeys, pnumMax packetNumber) (p lon
}
// Length (i),
- payLen, n := consumeVarint(b)
+ payLen, n := quicwire.ConsumeVarint(b)
if n < 0 {
return longPacket{}, -1
}
@@ -121,14 +123,14 @@ func skipLongHeaderPacket(pkt []byte) int {
}
if getPacketType(pkt) == packetTypeInitial {
// Token length, token.
- _, nn := consumeVarintBytes(pkt[n:])
+ _, nn := quicwire.ConsumeVarintBytes(pkt[n:])
if nn < 0 {
return -1
}
n += nn
}
// Length, packet number, payload.
- _, nn := consumeVarintBytes(pkt[n:])
+ _, nn := quicwire.ConsumeVarintBytes(pkt[n:])
if nn < 0 {
return -1
}
@@ -160,20 +162,20 @@ func parse1RTTPacket(pkt []byte, k *updatingKeyPair, dstConnIDLen int, pnumMax p
func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumber)) (largest packetNumber, ackDelay unscaledAckDelay, n int) {
b := frame[1:] // type
- largestAck, n := consumeVarint(b)
+ largestAck, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
- v, n := consumeVarintInt64(b)
+ v, n := quicwire.ConsumeVarintInt64(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
ackDelay = unscaledAckDelay(v)
- ackRangeCount, n := consumeVarint(b)
+ ackRangeCount, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -181,7 +183,7 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
rangeMax := packetNumber(largestAck)
for i := uint64(0); ; i++ {
- rangeLen, n := consumeVarint(b)
+ rangeLen, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -196,7 +198,7 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
break
}
- gap, n := consumeVarint(b)
+ gap, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -209,17 +211,17 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
return packetNumber(largestAck), ackDelay, len(frame) - len(b)
}
- ect0Count, n := consumeVarint(b)
+ ect0Count, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
- ect1Count, n := consumeVarint(b)
+ ect1Count, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
b = b[n:]
- ecnCECount, n := consumeVarint(b)
+ ecnCECount, n := quicwire.ConsumeVarint(b)
if n < 0 {
return 0, 0, -1
}
@@ -236,17 +238,17 @@ func consumeAckFrame(frame []byte, f func(rangeIndex int, start, end packetNumbe
func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int64, n int) {
n = 1
- idInt, nn := consumeVarint(b[n:])
+ idInt, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, 0, -1
}
n += nn
- code, nn = consumeVarint(b[n:])
+ code, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, 0, -1
}
n += nn
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, 0, -1
}
@@ -257,12 +259,12 @@ func consumeResetStreamFrame(b []byte) (id streamID, code uint64, finalSize int6
func consumeStopSendingFrame(b []byte) (id streamID, code uint64, n int) {
n = 1
- idInt, nn := consumeVarint(b[n:])
+ idInt, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
n += nn
- code, nn = consumeVarint(b[n:])
+ code, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -272,13 +274,13 @@ func consumeStopSendingFrame(b []byte) (id streamID, code uint64, n int) {
func consumeCryptoFrame(b []byte) (off int64, data []byte, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, nil, -1
}
off = int64(v)
n += nn
- data, nn = consumeVarintBytes(b[n:])
+ data, nn = quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, nil, -1
}
@@ -288,7 +290,7 @@ func consumeCryptoFrame(b []byte) (off int64, data []byte, n int) {
func consumeNewTokenFrame(b []byte) (token []byte, n int) {
n = 1
- data, nn := consumeVarintBytes(b[n:])
+ data, nn := quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return nil, -1
}
@@ -302,13 +304,13 @@ func consumeNewTokenFrame(b []byte) (token []byte, n int) {
func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte, n int) {
fin = (b[0] & 0x01) != 0
n = 1
- idInt, nn := consumeVarint(b[n:])
+ idInt, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, false, nil, -1
}
n += nn
if b[0]&0x04 != 0 {
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, false, nil, -1
}
@@ -316,7 +318,7 @@ func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte
off = int64(v)
}
if b[0]&0x02 != 0 {
- data, nn = consumeVarintBytes(b[n:])
+ data, nn = quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, 0, false, nil, -1
}
@@ -333,7 +335,7 @@ func consumeStreamFrame(b []byte) (id streamID, off int64, fin bool, data []byte
func consumeMaxDataFrame(b []byte) (max int64, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, -1
}
@@ -343,13 +345,13 @@ func consumeMaxDataFrame(b []byte) (max int64, n int) {
func consumeMaxStreamDataFrame(b []byte) (id streamID, max int64, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
n += nn
id = streamID(v)
- v, nn = consumeVarint(b[n:])
+ v, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -368,7 +370,7 @@ func consumeMaxStreamsFrame(b []byte) (typ streamType, max int64, n int) {
return 0, 0, -1
}
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -381,13 +383,13 @@ func consumeMaxStreamsFrame(b []byte) (typ streamType, max int64, n int) {
func consumeStreamDataBlockedFrame(b []byte) (id streamID, max int64, n int) {
n = 1
- v, nn := consumeVarint(b[n:])
+ v, nn := quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, -1
}
n += nn
id = streamID(v)
- max, nn = consumeVarintInt64(b[n:])
+ max, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -397,7 +399,7 @@ func consumeStreamDataBlockedFrame(b []byte) (id streamID, max int64, n int) {
func consumeDataBlockedFrame(b []byte) (max int64, n int) {
n = 1
- max, nn := consumeVarintInt64(b[n:])
+ max, nn := quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, -1
}
@@ -412,7 +414,7 @@ func consumeStreamsBlockedFrame(b []byte) (typ streamType, max int64, n int) {
typ = uniStream
}
n = 1
- max, nn := consumeVarintInt64(b[n:])
+ max, nn := quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, 0, -1
}
@@ -423,12 +425,12 @@ func consumeStreamsBlockedFrame(b []byte) (typ streamType, max int64, n int) {
func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, resetToken statelessResetToken, n int) {
n = 1
var nn int
- seq, nn = consumeVarintInt64(b[n:])
+ seq, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, 0, nil, statelessResetToken{}, -1
}
n += nn
- retire, nn = consumeVarintInt64(b[n:])
+ retire, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, 0, nil, statelessResetToken{}, -1
}
@@ -436,7 +438,7 @@ func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, re
if seq < retire {
return 0, 0, nil, statelessResetToken{}, -1
}
- connID, nn = consumeVarintBytes(b[n:])
+ connID, nn = quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, 0, nil, statelessResetToken{}, -1
}
@@ -455,7 +457,7 @@ func consumeNewConnectionIDFrame(b []byte) (seq, retire int64, connID []byte, re
func consumeRetireConnectionIDFrame(b []byte) (seq int64, n int) {
n = 1
var nn int
- seq, nn = consumeVarintInt64(b[n:])
+ seq, nn = quicwire.ConsumeVarintInt64(b[n:])
if nn < 0 {
return 0, -1
}
@@ -481,18 +483,18 @@ func consumeConnectionCloseTransportFrame(b []byte) (code transportError, frameT
n = 1
var nn int
var codeInt uint64
- codeInt, nn = consumeVarint(b[n:])
+ codeInt, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, "", -1
}
code = transportError(codeInt)
n += nn
- frameType, nn = consumeVarint(b[n:])
+ frameType, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, 0, "", -1
}
n += nn
- reasonb, nn := consumeVarintBytes(b[n:])
+ reasonb, nn := quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, 0, "", -1
}
@@ -504,12 +506,12 @@ func consumeConnectionCloseTransportFrame(b []byte) (code transportError, frameT
func consumeConnectionCloseApplicationFrame(b []byte) (code uint64, reason string, n int) {
n = 1
var nn int
- code, nn = consumeVarint(b[n:])
+ code, nn = quicwire.ConsumeVarint(b[n:])
if nn < 0 {
return 0, "", -1
}
n += nn
- reasonb, nn := consumeVarintBytes(b[n:])
+ reasonb, nn := quicwire.ConsumeVarintBytes(b[n:])
if nn < 0 {
return 0, "", -1
}
diff --git a/quic/packet_protection.go b/quic/packet_protection.go
index fe48c14c5d..9f1bbc6a4a 100644
--- a/quic/packet_protection.go
+++ b/quic/packet_protection.go
@@ -519,7 +519,7 @@ func hashForSuite(suite uint16) (h crypto.Hash, keySize int) {
}
}
-// hdkfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
+// hkdfExpandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1.
//
// Copied from crypto/tls/key_schedule.go.
func hkdfExpandLabel(hash func() hash.Hash, secret []byte, label string, context []byte, length int) []byte {
diff --git a/quic/packet_writer.go b/quic/packet_writer.go
index e4d71e622b..e75edcda5b 100644
--- a/quic/packet_writer.go
+++ b/quic/packet_writer.go
@@ -8,6 +8,8 @@ package quic
import (
"encoding/binary"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// A packetWriter constructs QUIC datagrams.
@@ -47,7 +49,7 @@ func (w *packetWriter) datagram() []byte {
return w.b
}
-// packet returns the size of the current packet.
+// packetLen returns the size of the current packet.
func (w *packetWriter) packetLen() int {
return len(w.b[w.pktOff:]) + aeadOverhead
}
@@ -74,7 +76,7 @@ func (w *packetWriter) startProtectedLongHeaderPacket(pnumMaxAcked packetNumber,
hdrSize += 1 + len(p.srcConnID)
switch p.ptype {
case packetTypeInitial:
- hdrSize += sizeVarint(uint64(len(p.extra))) + len(p.extra)
+ hdrSize += quicwire.SizeVarint(uint64(len(p.extra))) + len(p.extra)
}
hdrSize += 2 // length, hardcoded to a 2-byte varint
pnumOff := len(w.b) + hdrSize
@@ -127,11 +129,11 @@ func (w *packetWriter) finishProtectedLongHeaderPacket(pnumMaxAcked packetNumber
}
hdr = append(hdr, headerFormLong|fixedBit|typeBits|byte(pnumLen-1))
hdr = binary.BigEndian.AppendUint32(hdr, p.version)
- hdr = appendUint8Bytes(hdr, p.dstConnID)
- hdr = appendUint8Bytes(hdr, p.srcConnID)
+ hdr = quicwire.AppendUint8Bytes(hdr, p.dstConnID)
+ hdr = quicwire.AppendUint8Bytes(hdr, p.srcConnID)
switch p.ptype {
case packetTypeInitial:
- hdr = appendVarintBytes(hdr, p.extra) // token
+ hdr = quicwire.AppendVarintBytes(hdr, p.extra) // token
}
// Packet length, always encoded as a 2-byte varint.
@@ -270,26 +272,26 @@ func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscale
largest = uint64(seen.max())
firstRange = uint64(seen[len(seen)-1].size() - 1)
)
- if w.avail() < 1+sizeVarint(largest)+sizeVarint(uint64(delay))+1+sizeVarint(firstRange) {
+ if w.avail() < 1+quicwire.SizeVarint(largest)+quicwire.SizeVarint(uint64(delay))+1+quicwire.SizeVarint(firstRange) {
return false
}
w.b = append(w.b, frameTypeAck)
- w.b = appendVarint(w.b, largest)
- w.b = appendVarint(w.b, uint64(delay))
+ w.b = quicwire.AppendVarint(w.b, largest)
+ w.b = quicwire.AppendVarint(w.b, uint64(delay))
// The range count is technically a varint, but we'll reserve a single byte for it
// and never add more than 62 ranges (the maximum varint that fits in a byte).
rangeCountOff := len(w.b)
w.b = append(w.b, 0)
- w.b = appendVarint(w.b, firstRange)
+ w.b = quicwire.AppendVarint(w.b, firstRange)
rangeCount := byte(0)
for i := len(seen) - 2; i >= 0; i-- {
gap := uint64(seen[i+1].start - seen[i].end - 1)
size := uint64(seen[i].size() - 1)
- if w.avail() < sizeVarint(gap)+sizeVarint(size) || rangeCount > 62 {
+ if w.avail() < quicwire.SizeVarint(gap)+quicwire.SizeVarint(size) || rangeCount > 62 {
break
}
- w.b = appendVarint(w.b, gap)
- w.b = appendVarint(w.b, size)
+ w.b = quicwire.AppendVarint(w.b, gap)
+ w.b = quicwire.AppendVarint(w.b, size)
rangeCount++
}
w.b[rangeCountOff] = rangeCount
@@ -299,34 +301,34 @@ func (w *packetWriter) appendAckFrame(seen rangeset[packetNumber], delay unscale
}
func (w *packetWriter) appendNewTokenFrame(token []byte) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(len(token)))+len(token) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(len(token)))+len(token) {
return false
}
w.b = append(w.b, frameTypeNewToken)
- w.b = appendVarintBytes(w.b, token)
+ w.b = quicwire.AppendVarintBytes(w.b, token)
return true
}
func (w *packetWriter) appendResetStreamFrame(id streamID, code uint64, finalSize int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(code)+sizeVarint(uint64(finalSize)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(finalSize)) {
return false
}
w.b = append(w.b, frameTypeResetStream)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, code)
- w.b = appendVarint(w.b, uint64(finalSize))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, code)
+ w.b = quicwire.AppendVarint(w.b, uint64(finalSize))
w.sent.appendAckElicitingFrame(frameTypeResetStream)
w.sent.appendInt(uint64(id))
return true
}
func (w *packetWriter) appendStopSendingFrame(id streamID, code uint64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(code) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(code) {
return false
}
w.b = append(w.b, frameTypeStopSending)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, code)
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, code)
w.sent.appendAckElicitingFrame(frameTypeStopSending)
w.sent.appendInt(uint64(id))
return true
@@ -337,9 +339,9 @@ func (w *packetWriter) appendStopSendingFrame(id streamID, code uint64) (added b
// The returned []byte may be smaller than size if the packet cannot hold all the data.
func (w *packetWriter) appendCryptoFrame(off int64, size int) (_ []byte, added bool) {
max := w.avail()
- max -= 1 // frame type
- max -= sizeVarint(uint64(off)) // offset
- max -= sizeVarint(uint64(size)) // maximum length
+ max -= 1 // frame type
+ max -= quicwire.SizeVarint(uint64(off)) // offset
+ max -= quicwire.SizeVarint(uint64(size)) // maximum length
if max <= 0 {
return nil, false
}
@@ -347,8 +349,8 @@ func (w *packetWriter) appendCryptoFrame(off int64, size int) (_ []byte, added b
size = max
}
w.b = append(w.b, frameTypeCrypto)
- w.b = appendVarint(w.b, uint64(off))
- w.b = appendVarint(w.b, uint64(size))
+ w.b = quicwire.AppendVarint(w.b, uint64(off))
+ w.b = quicwire.AppendVarint(w.b, uint64(size))
start := len(w.b)
w.b = w.b[:start+size]
w.sent.appendAckElicitingFrame(frameTypeCrypto)
@@ -363,12 +365,12 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b
typ := uint8(frameTypeStreamBase | streamLenBit)
max := w.avail()
max -= 1 // frame type
- max -= sizeVarint(uint64(id))
+ max -= quicwire.SizeVarint(uint64(id))
if off != 0 {
- max -= sizeVarint(uint64(off))
+ max -= quicwire.SizeVarint(uint64(off))
typ |= streamOffBit
}
- max -= sizeVarint(uint64(size)) // maximum length
+ max -= quicwire.SizeVarint(uint64(size)) // maximum length
if max < 0 || (max == 0 && size > 0) {
return nil, false
}
@@ -378,11 +380,11 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b
typ |= streamFinBit
}
w.b = append(w.b, typ)
- w.b = appendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
if off != 0 {
- w.b = appendVarint(w.b, uint64(off))
+ w.b = quicwire.AppendVarint(w.b, uint64(off))
}
- w.b = appendVarint(w.b, uint64(size))
+ w.b = quicwire.AppendVarint(w.b, uint64(size))
start := len(w.b)
w.b = w.b[:start+size]
w.sent.appendAckElicitingFrame(typ & (frameTypeStreamBase | streamFinBit))
@@ -392,29 +394,29 @@ func (w *packetWriter) appendStreamFrame(id streamID, off int64, size int, fin b
}
func (w *packetWriter) appendMaxDataFrame(max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeMaxData)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeMaxData)
return true
}
func (w *packetWriter) appendMaxStreamDataFrame(id streamID, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeMaxStreamData)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeMaxStreamData)
w.sent.appendInt(uint64(id))
return true
}
func (w *packetWriter) appendMaxStreamsFrame(streamType streamType, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
var typ byte
@@ -424,35 +426,35 @@ func (w *packetWriter) appendMaxStreamsFrame(streamType streamType, max int64) (
typ = frameTypeMaxStreamsUni
}
w.b = append(w.b, typ)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(typ)
return true
}
func (w *packetWriter) appendDataBlockedFrame(max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeDataBlocked)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeDataBlocked)
return true
}
func (w *packetWriter) appendStreamDataBlockedFrame(id streamID, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(id))+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(id))+quicwire.SizeVarint(uint64(max)) {
return false
}
w.b = append(w.b, frameTypeStreamDataBlocked)
- w.b = appendVarint(w.b, uint64(id))
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(id))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(frameTypeStreamDataBlocked)
w.sent.appendInt(uint64(id))
return true
}
func (w *packetWriter) appendStreamsBlockedFrame(typ streamType, max int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(max)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(max)) {
return false
}
var ftype byte
@@ -462,19 +464,19 @@ func (w *packetWriter) appendStreamsBlockedFrame(typ streamType, max int64) (add
ftype = frameTypeStreamsBlockedUni
}
w.b = append(w.b, ftype)
- w.b = appendVarint(w.b, uint64(max))
+ w.b = quicwire.AppendVarint(w.b, uint64(max))
w.sent.appendAckElicitingFrame(ftype)
return true
}
func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, connID []byte, token [16]byte) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(seq))+sizeVarint(uint64(retirePriorTo))+1+len(connID)+len(token) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(seq))+quicwire.SizeVarint(uint64(retirePriorTo))+1+len(connID)+len(token) {
return false
}
w.b = append(w.b, frameTypeNewConnectionID)
- w.b = appendVarint(w.b, uint64(seq))
- w.b = appendVarint(w.b, uint64(retirePriorTo))
- w.b = appendUint8Bytes(w.b, connID)
+ w.b = quicwire.AppendVarint(w.b, uint64(seq))
+ w.b = quicwire.AppendVarint(w.b, uint64(retirePriorTo))
+ w.b = quicwire.AppendUint8Bytes(w.b, connID)
w.b = append(w.b, token[:]...)
w.sent.appendAckElicitingFrame(frameTypeNewConnectionID)
w.sent.appendInt(uint64(seq))
@@ -482,11 +484,11 @@ func (w *packetWriter) appendNewConnectionIDFrame(seq, retirePriorTo int64, conn
}
func (w *packetWriter) appendRetireConnectionIDFrame(seq int64) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(seq)) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(seq)) {
return false
}
w.b = append(w.b, frameTypeRetireConnectionID)
- w.b = appendVarint(w.b, uint64(seq))
+ w.b = quicwire.AppendVarint(w.b, uint64(seq))
w.sent.appendAckElicitingFrame(frameTypeRetireConnectionID)
w.sent.appendInt(uint64(seq))
return true
@@ -515,27 +517,27 @@ func (w *packetWriter) appendPathResponseFrame(data pathChallengeData) (added bo
// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame
// carrying a transport error code.
func (w *packetWriter) appendConnectionCloseTransportFrame(code transportError, frameType uint64, reason string) (added bool) {
- if w.avail() < 1+sizeVarint(uint64(code))+sizeVarint(frameType)+sizeVarint(uint64(len(reason)))+len(reason) {
+ if w.avail() < 1+quicwire.SizeVarint(uint64(code))+quicwire.SizeVarint(frameType)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) {
return false
}
w.b = append(w.b, frameTypeConnectionCloseTransport)
- w.b = appendVarint(w.b, uint64(code))
- w.b = appendVarint(w.b, frameType)
- w.b = appendVarintBytes(w.b, []byte(reason))
+ w.b = quicwire.AppendVarint(w.b, uint64(code))
+ w.b = quicwire.AppendVarint(w.b, frameType)
+ w.b = quicwire.AppendVarintBytes(w.b, []byte(reason))
// We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or
// detected as lost.
return true
}
-// appendConnectionCloseTransportFrame appends a CONNECTION_CLOSE frame
+// appendConnectionCloseApplicationFrame appends a CONNECTION_CLOSE frame
// carrying an application protocol error code.
func (w *packetWriter) appendConnectionCloseApplicationFrame(code uint64, reason string) (added bool) {
- if w.avail() < 1+sizeVarint(code)+sizeVarint(uint64(len(reason)))+len(reason) {
+ if w.avail() < 1+quicwire.SizeVarint(code)+quicwire.SizeVarint(uint64(len(reason)))+len(reason) {
return false
}
w.b = append(w.b, frameTypeConnectionCloseApplication)
- w.b = appendVarint(w.b, code)
- w.b = appendVarintBytes(w.b, []byte(reason))
+ w.b = quicwire.AppendVarint(w.b, code)
+ w.b = quicwire.AppendVarintBytes(w.b, []byte(reason))
// We don't record CONNECTION_CLOSE frames in w.sent, since they are never acked or
// detected as lost.
return true
diff --git a/quic/qlog/json_writer.go b/quic/qlog/json_writer.go
index 6fb8d33b25..7867c590df 100644
--- a/quic/qlog/json_writer.go
+++ b/quic/qlog/json_writer.go
@@ -58,7 +58,7 @@ func (w *jsonWriter) writeAttr(a slog.Attr) {
w.writeValue(a.Value)
}
-// writeAttr writes a []slog.Attr as an object field.
+// writeAttrsField writes a []slog.Attr as an object field.
func (w *jsonWriter) writeAttrsField(name string, attrs []slog.Attr) {
w.writeName(name)
w.writeAttrs(attrs)
@@ -113,7 +113,7 @@ func (w *jsonWriter) writeObject(f func()) {
w.buf.WriteByte('}')
}
-// writeObject writes an object-valued object field.
+// writeObjectField writes an object-valued object field.
// The function f is called to write the contents.
func (w *jsonWriter) writeObjectField(name string, f func()) {
w.writeName(name)
diff --git a/quic/rangeset.go b/quic/rangeset.go
index b8b2e93672..528d53df39 100644
--- a/quic/rangeset.go
+++ b/quic/rangeset.go
@@ -159,6 +159,14 @@ func (s rangeset[T]) numRanges() int {
return len(s)
}
+// size returns the size of all ranges in the rangeset.
+func (s rangeset[T]) size() (total T) {
+ for _, r := range s {
+ total += r.size()
+ }
+ return total
+}
+
// isrange reports if the rangeset covers exactly the range [start, end).
func (s rangeset[T]) isrange(start, end T) bool {
switch len(s) {
diff --git a/quic/retry.go b/quic/retry.go
index 5dc39d1d9d..8c56ee1b10 100644
--- a/quic/retry.go
+++ b/quic/retry.go
@@ -16,6 +16,7 @@ import (
"time"
"golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/net/internal/quic/quicwire"
)
// AEAD and nonce used to compute the Retry Integrity Tag.
@@ -133,7 +134,7 @@ func (rs *retryState) validateToken(now time.Time, token, srcConnID, dstConnID [
func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []byte {
var additional []byte
- additional = appendUint8Bytes(additional, srcConnID)
+ additional = quicwire.AppendUint8Bytes(additional, srcConnID)
additional = append(additional, addr.Addr().AsSlice()...)
additional = binary.BigEndian.AppendUint16(additional, addr.Port())
return additional
@@ -141,7 +142,7 @@ func (rs *retryState) additionalData(srcConnID []byte, addr netip.AddrPort) []by
func (e *Endpoint) validateInitialAddress(now time.Time, p genericLongPacket, peerAddr netip.AddrPort) (origDstConnID []byte, ok bool) {
// The retry token is at the start of an Initial packet's data.
- token, n := consumeUint8Bytes(p.data)
+ token, n := quicwire.ConsumeUint8Bytes(p.data)
if n < 0 {
// We've already validated that the packet is at least 1200 bytes long,
// so there's no way for even a maximum size token to not fit.
@@ -196,12 +197,12 @@ func encodeRetryPacket(originalDstConnID []byte, p retryPacket) []byte {
// Create the pseudo-packet (including the original DCID), append the tag,
// and return the Retry packet.
var b []byte
- b = appendUint8Bytes(b, originalDstConnID) // Original Destination Connection ID
- start := len(b) // start of the Retry packet
+ b = quicwire.AppendUint8Bytes(b, originalDstConnID) // Original Destination Connection ID
+ start := len(b) // start of the Retry packet
b = append(b, headerFormLong|fixedBit|longPacketTypeRetry)
b = binary.BigEndian.AppendUint32(b, quicVersion1) // Version
- b = appendUint8Bytes(b, p.dstConnID) // Destination Connection ID
- b = appendUint8Bytes(b, p.srcConnID) // Source Connection ID
+ b = quicwire.AppendUint8Bytes(b, p.dstConnID) // Destination Connection ID
+ b = quicwire.AppendUint8Bytes(b, p.srcConnID) // Source Connection ID
b = append(b, p.token...) // Token
b = retryAEAD.Seal(b, retryNonce, nil, b) // Retry Integrity Tag
return b[start:]
@@ -222,7 +223,7 @@ func parseRetryPacket(b, origDstConnID []byte) (p retryPacket, ok bool) {
// Create the pseudo-packet consisting of the original destination connection ID
// followed by the Retry packet (less the integrity tag).
// Use this to validate the packet integrity tag.
- pseudo := appendUint8Bytes(nil, origDstConnID)
+ pseudo := quicwire.AppendUint8Bytes(nil, origDstConnID)
pseudo = append(pseudo, b[:len(b)-retryIntegrityTagLength]...)
wantTag := retryAEAD.Seal(nil, retryNonce, nil, pseudo)
if !bytes.Equal(gotTag, wantTag) {
diff --git a/quic/rtt.go b/quic/rtt.go
index 4942f8cca1..494060c67d 100644
--- a/quic/rtt.go
+++ b/quic/rtt.go
@@ -37,7 +37,7 @@ func (r *rttState) establishPersistentCongestion() {
r.minRTT = r.latestRTT
}
-// updateRTTSample is called when we generate a new RTT sample.
+// updateSample is called when we generate a new RTT sample.
// https://www.rfc-editor.org/rfc/rfc9002.html#section-5
func (r *rttState) updateSample(now time.Time, handshakeConfirmed bool, spaceID numberSpace, latestRTT, ackDelay, maxAckDelay time.Duration) {
r.latestRTT = latestRTT
diff --git a/quic/sent_packet.go b/quic/sent_packet.go
index 226152327d..eedd2f61b3 100644
--- a/quic/sent_packet.go
+++ b/quic/sent_packet.go
@@ -9,6 +9,8 @@ package quic
import (
"sync"
"time"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// A sentPacket tracks state related to an in-flight packet we sent,
@@ -78,12 +80,12 @@ func (sent *sentPacket) appendAckElicitingFrame(frameType byte) {
}
func (sent *sentPacket) appendInt(v uint64) {
- sent.b = appendVarint(sent.b, v)
+ sent.b = quicwire.AppendVarint(sent.b, v)
}
func (sent *sentPacket) appendOffAndSize(start int64, size int) {
- sent.b = appendVarint(sent.b, uint64(start))
- sent.b = appendVarint(sent.b, uint64(size))
+ sent.b = quicwire.AppendVarint(sent.b, uint64(start))
+ sent.b = quicwire.AppendVarint(sent.b, uint64(size))
}
// The next* methods read back information about frames in the packet.
@@ -95,7 +97,7 @@ func (sent *sentPacket) next() (frameType byte) {
}
func (sent *sentPacket) nextInt() uint64 {
- v, n := consumeVarint(sent.b[sent.n:])
+ v, n := quicwire.ConsumeVarint(sent.b[sent.n:])
sent.n += n
return v
}
diff --git a/quic/sent_val.go b/quic/sent_val.go
index 31f69e47d0..920658919b 100644
--- a/quic/sent_val.go
+++ b/quic/sent_val.go
@@ -37,7 +37,7 @@ func (s sentVal) isSet() bool { return s != 0 }
// shouldSend reports whether the value is set and has not been sent to the peer.
func (s sentVal) shouldSend() bool { return s.state() == sentValUnsent }
-// shouldSend reports whether the value needs to be sent to the peer.
+// shouldSendPTO reports whether the value needs to be sent to the peer.
// The value needs to be sent if it is set and has not been sent.
// If pto is true, indicating that we are sending a PTO probe, the value
// should also be sent if it is set and has not been acknowledged.
diff --git a/quic/stream.go b/quic/stream.go
index cb45534f82..8068b10acd 100644
--- a/quic/stream.go
+++ b/quic/stream.go
@@ -12,6 +12,8 @@ import (
"fmt"
"io"
"math"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// A Stream is an ordered byte stream.
@@ -254,6 +256,11 @@ func (s *Stream) Read(b []byte) (n int, err error) {
s.conn.handleStreamBytesReadOffLoop(bytesRead) // must be done with ingate unlocked
}()
if s.inresetcode != -1 {
+ if s.inresetcode == streamResetByConnClose {
+ if err := s.conn.finalError(); err != nil {
+ return 0, err
+ }
+ }
return 0, fmt.Errorf("stream reset by peer: %w", StreamErrorCode(s.inresetcode))
}
if s.inclosed.isSet() {
@@ -308,7 +315,7 @@ func (s *Stream) ReadByte() (byte, error) {
var b [1]byte
n, err := s.Read(b[:])
if n > 0 {
- return b[0], err
+ return b[0], nil
}
return 0, err
}
@@ -352,13 +359,9 @@ func (s *Stream) Write(b []byte) (n int, err error) {
// write blocked. (Unlike traditional condition variables, gates do not
// have spurious wakeups.)
}
- if s.outreset.isSet() {
- s.outUnlock()
- return n, errors.New("write to reset stream")
- }
- if s.outclosed.isSet() {
+ if err := s.writeErrorLocked(); err != nil {
s.outUnlock()
- return n, errors.New("write to closed stream")
+ return n, err
}
if len(b) == 0 {
break
@@ -418,7 +421,7 @@ func (s *Stream) Write(b []byte) (n int, err error) {
return n, nil
}
-// WriteBytes writes a single byte to the stream.
+// WriteByte writes a single byte to the stream.
func (s *Stream) WriteByte(c byte) error {
if s.outbufoff < len(s.outbuf) {
s.outbuf[s.outbufoff] = c
@@ -445,10 +448,34 @@ func (s *Stream) flushFastOutputBuffer() {
// Flush flushes data written to the stream.
// It does not wait for the peer to acknowledge receipt of the data.
// Use Close to wait for the peer's acknowledgement.
-func (s *Stream) Flush() {
+func (s *Stream) Flush() error {
+ if s.IsReadOnly() {
+ return errors.New("flush of read-only stream")
+ }
s.outgate.lock()
defer s.outUnlock()
+ if err := s.writeErrorLocked(); err != nil {
+ return err
+ }
s.flushLocked()
+ return nil
+}
+
+// writeErrorLocked returns the error (if any) which should be returned by write operations
+// due to the stream being reset or closed.
+func (s *Stream) writeErrorLocked() error {
+ if s.outreset.isSet() {
+ if s.outresetcode == streamResetByConnClose {
+ if err := s.conn.finalError(); err != nil {
+ return err
+ }
+ }
+ return errors.New("write to reset stream")
+ }
+ if s.outclosed.isSet() {
+ return errors.New("write to closed stream")
+ }
+ return nil
}
func (s *Stream) flushLocked() {
@@ -560,8 +587,8 @@ func (s *Stream) resetInternal(code uint64, userClosed bool) {
if s.outreset.isSet() {
return
}
- if code > maxVarint {
- code = maxVarint
+ if code > quicwire.MaxVarint {
+ code = quicwire.MaxVarint
}
// We could check here to see if the stream is closed and the
// peer has acked all the data and the FIN, but sending an
@@ -595,8 +622,11 @@ func (s *Stream) connHasClosed() {
s.outgate.lock()
if localClose {
s.outclosed.set()
+ s.outreset.set()
+ } else {
+ s.outresetcode = streamResetByConnClose
+ s.outreset.setReceived()
}
- s.outreset.set()
s.outUnlock()
}
diff --git a/quic/stream_test.go b/quic/stream_test.go
index 9f857f29d4..2643ae3dba 100644
--- a/quic/stream_test.go
+++ b/quic/stream_test.go
@@ -15,6 +15,8 @@ import (
"io"
"strings"
"testing"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
func TestStreamWriteBlockedByOutputBuffer(t *testing.T) {
@@ -566,6 +568,25 @@ func TestStreamReceiveEmptyEOF(t *testing.T) {
})
}
+func TestStreamReadByteFromOneByteStream(t *testing.T) {
+ // ReadByte on the only byte of a stream should not return an error.
+ testStreamTypes(t, "", func(t *testing.T, styp streamType) {
+ tc, s := newTestConnAndRemoteStream(t, serverSide, styp, permissiveTransportParameters)
+ want := byte(1)
+ tc.writeFrames(packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: []byte{want},
+ fin: true,
+ })
+ if got, err := s.ReadByte(); got != want || err != nil {
+ t.Fatalf("s.ReadByte() = %v, %v; want %v, nil", got, err, want)
+ }
+ if got, err := s.ReadByte(); err != io.EOF {
+ t.Fatalf("s.ReadByte() = %v, %v; want _, EOF", got, err)
+ }
+ })
+}
+
func finalSizeTest(t *testing.T, wantErr transportError, f func(tc *testConn, sid streamID) (finalSize int64), opts ...any) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
for _, test := range []struct {
@@ -1324,6 +1345,61 @@ func TestStreamFlushExplicit(t *testing.T) {
})
}
+func TestStreamFlushClosedStream(t *testing.T) {
+ _, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ s.Close()
+ if err := s.Flush(); err == nil {
+ t.Errorf("s.Flush of closed stream = nil, want error")
+ }
+}
+
+func TestStreamFlushResetStream(t *testing.T) {
+ _, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ s.Reset(0)
+ if err := s.Flush(); err == nil {
+ t.Errorf("s.Flush of reset stream = nil, want error")
+ }
+}
+
+func TestStreamFlushStreamAfterPeerStopSending(t *testing.T) {
+ tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ s.Flush() // create the stream
+ tc.wantFrame("stream created after flush",
+ packetType1RTT, debugFrameStream{
+ id: s.id,
+ data: []byte{},
+ })
+
+ // Peer sends a STOP_SENDING.
+ tc.writeFrames(packetType1RTT, debugFrameStopSending{
+ id: s.id,
+ })
+ if err := s.Flush(); err == nil {
+ t.Errorf("s.Flush of stream reset by peer = nil, want error")
+ }
+}
+
+func TestStreamErrorsAfterConnectionClosed(t *testing.T) {
+ tc, s := newTestConnAndLocalStream(t, clientSide, bidiStream,
+ permissiveTransportParameters)
+ wantErr := &ApplicationError{Code: 42}
+ tc.writeFrames(packetType1RTT, debugFrameConnectionCloseApplication{
+ code: wantErr.Code,
+ })
+ if _, err := s.Read(make([]byte, 1)); !errors.Is(err, wantErr) {
+ t.Errorf("s.Read on closed connection = %v, want %v", err, wantErr)
+ }
+ if _, err := s.Write(make([]byte, 1)); !errors.Is(err, wantErr) {
+ t.Errorf("s.Write on closed connection = %v, want %v", err, wantErr)
+ }
+ if err := s.Flush(); !errors.Is(err, wantErr) {
+ t.Errorf("s.Flush on closed connection = %v, want %v", err, wantErr)
+ }
+}
+
func TestStreamFlushImplicitExact(t *testing.T) {
testStreamTypes(t, "", func(t *testing.T, styp streamType) {
const writeBufferSize = 4
@@ -1467,10 +1543,10 @@ func newRemoteStream(t *testing.T, tc *testConn, styp streamType) *Stream {
func permissiveTransportParameters(p *transportParameters) {
p.initialMaxStreamsBidi = maxStreamsLimit
p.initialMaxStreamsUni = maxStreamsLimit
- p.initialMaxData = maxVarint
- p.initialMaxStreamDataBidiRemote = maxVarint
- p.initialMaxStreamDataBidiLocal = maxVarint
- p.initialMaxStreamDataUni = maxVarint
+ p.initialMaxData = quicwire.MaxVarint
+ p.initialMaxStreamDataBidiRemote = quicwire.MaxVarint
+ p.initialMaxStreamDataBidiLocal = quicwire.MaxVarint
+ p.initialMaxStreamDataUni = quicwire.MaxVarint
}
func makeTestData(n int) []byte {
diff --git a/quic/tls.go b/quic/tls.go
index e2f2e5bde1..89b31842cd 100644
--- a/quic/tls.go
+++ b/quic/tls.go
@@ -119,11 +119,7 @@ func (c *Conn) handleCrypto(now time.Time, space numberSpace, off int64, data []
default:
return errors.New("quic: internal error: received CRYPTO frame in unexpected number space")
}
- err := c.crypto[space].handleCrypto(off, data, func(b []byte) error {
+ return c.crypto[space].handleCrypto(off, data, func(b []byte) error {
return c.tls.HandleData(level, b)
})
- if err != nil {
- return err
- }
- return c.handleTLSEvents(now)
}
diff --git a/quic/tls_test.go b/quic/tls_test.go
index 9c1dd364ec..f4abdda582 100644
--- a/quic/tls_test.go
+++ b/quic/tls_test.go
@@ -615,3 +615,32 @@ func TestConnAEADLimitReached(t *testing.T) {
tc.advance(1 * time.Second)
tc.wantIdle("auth failures at limit: conn does not process additional packets")
}
+
+func TestConnKeysDiscardedWithExcessCryptoData(t *testing.T) {
+ tc := newTestConn(t, serverSide, permissiveTransportParameters)
+ tc.ignoreFrame(frameTypeAck)
+ tc.ignoreFrame(frameTypeNewConnectionID)
+ tc.ignoreFrame(frameTypeCrypto)
+
+ // One byte of excess CRYPTO data, separated from the valid data by a one-byte gap.
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ off: int64(len(tc.cryptoDataIn[tls.QUICEncryptionLevelInitial]) + 1),
+ data: []byte{0},
+ })
+ tc.writeFrames(packetTypeInitial,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelInitial],
+ })
+
+ // We don't drop the Initial keys and discover the excess data until the client
+ // sends a Handshake packet.
+ tc.writeFrames(packetTypeHandshake,
+ debugFrameCrypto{
+ data: tc.cryptoDataIn[tls.QUICEncryptionLevelHandshake],
+ })
+ tc.wantFrame("connection closed due to excess Initial CRYPTO data",
+ packetType1RTT, debugFrameConnectionCloseTransport{
+ code: errTLSBase + 10,
+ })
+}
diff --git a/quic/tlsconfig_test.go b/quic/tlsconfig_test.go
index 5ed9818d57..e24cef08ae 100644
--- a/quic/tlsconfig_test.go
+++ b/quic/tlsconfig_test.go
@@ -8,7 +8,8 @@ package quic
import (
"crypto/tls"
- "strings"
+
+ "golang.org/x/net/internal/testcert"
)
func newTestTLSConfig(side connSide) *tls.Config {
@@ -47,35 +48,9 @@ func newTestTLSConfigWithMoreDefaults(side connSide) *tls.Config {
}
var testCert = func() tls.Certificate {
- cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+ cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
if err != nil {
panic(err)
}
return cert
}()
-
-// localhostCert is a PEM-encoded TLS cert with SAN IPs
-// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
-// generated from src/crypto/tls:
-// go run generate_cert.go --ecdsa-curve P256 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
-var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBrDCCAVKgAwIBAgIPCvPhO+Hfv+NW76kWxULUMAoGCCqGSM49BAMCMBIxEDAO
-BgNVBAoTB0FjbWUgQ28wIBcNNzAwMTAxMDAwMDAwWhgPMjA4NDAxMjkxNjAwMDBa
-MBIxEDAOBgNVBAoTB0FjbWUgQ28wWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARh
-WRF8p8X9scgW7JjqAwI9nYV8jtkdhqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGms
-PyfMPe5Jrha/LmjgR1G9o4GIMIGFMA4GA1UdDwEB/wQEAwIChDATBgNVHSUEDDAK
-BggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSOJri/wLQxq6oC
-Y6ZImms/STbTljAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAA
-AAAAAAAAAAAAATAKBggqhkjOPQQDAgNIADBFAiBUguxsW6TGhixBAdORmVNnkx40
-HjkKwncMSDbUaeL9jQIhAJwQ8zV9JpQvYpsiDuMmqCuW35XXil3cQ6Drz82c+fvE
------END CERTIFICATE-----`)
-
-// localhostKey is the private key for localhostCert.
-var localhostKey = []byte(testingKey(`-----BEGIN TESTING KEY-----
-MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgY1B1eL/Bbwf/MDcs
-rnvvWhFNr1aGmJJR59PdCN9lVVqhRANCAARhWRF8p8X9scgW7JjqAwI9nYV8jtkd
-hqAXG9gyEgnaFNN5Ze9l3Tp1R9yCDBMNsGmsPyfMPe5Jrha/LmjgR1G9
------END TESTING KEY-----`))
-
-// testingKey helps keep security scanners from getting excited about a private key in this file.
-func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
diff --git a/quic/transport_params.go b/quic/transport_params.go
index 3cc56f4e44..13d1c7c7d5 100644
--- a/quic/transport_params.go
+++ b/quic/transport_params.go
@@ -10,6 +10,8 @@ import (
"encoding/binary"
"net/netip"
"time"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
// transportParameters transferred in the quic_transport_parameters TLS extension.
@@ -77,89 +79,89 @@ const (
func marshalTransportParameters(p transportParameters) []byte {
var b []byte
if v := p.originalDstConnID; v != nil {
- b = appendVarint(b, paramOriginalDestinationConnectionID)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramOriginalDestinationConnectionID)
+ b = quicwire.AppendVarintBytes(b, v)
}
if v := uint64(p.maxIdleTimeout / time.Millisecond); v != 0 {
- b = appendVarint(b, paramMaxIdleTimeout)
- b = appendVarint(b, uint64(sizeVarint(v)))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramMaxIdleTimeout)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v)))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.statelessResetToken; v != nil {
- b = appendVarint(b, paramStatelessResetToken)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramStatelessResetToken)
+ b = quicwire.AppendVarintBytes(b, v)
}
if v := p.maxUDPPayloadSize; v != defaultParamMaxUDPPayloadSize {
- b = appendVarint(b, paramMaxUDPPayloadSize)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramMaxUDPPayloadSize)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxData; v != 0 {
- b = appendVarint(b, paramInitialMaxData)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxData)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamDataBidiLocal; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamDataBidiLocal)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiLocal)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamDataBidiRemote; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamDataBidiRemote)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamDataBidiRemote)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamDataUni; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamDataUni)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamDataUni)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamsBidi; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamsBidi)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamsBidi)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialMaxStreamsUni; v != 0 {
- b = appendVarint(b, paramInitialMaxStreamsUni)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramInitialMaxStreamsUni)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.ackDelayExponent; v != defaultParamAckDelayExponent {
- b = appendVarint(b, paramAckDelayExponent)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramAckDelayExponent)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := uint64(p.maxAckDelay / time.Millisecond); v != defaultParamMaxAckDelayMilliseconds {
- b = appendVarint(b, paramMaxAckDelay)
- b = appendVarint(b, uint64(sizeVarint(v)))
- b = appendVarint(b, v)
+ b = quicwire.AppendVarint(b, paramMaxAckDelay)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(v)))
+ b = quicwire.AppendVarint(b, v)
}
if p.disableActiveMigration {
- b = appendVarint(b, paramDisableActiveMigration)
+ b = quicwire.AppendVarint(b, paramDisableActiveMigration)
b = append(b, 0) // 0-length value
}
if p.preferredAddrConnID != nil {
b = append(b, paramPreferredAddress)
- b = appendVarint(b, uint64(4+2+16+2+1+len(p.preferredAddrConnID)+16))
+ b = quicwire.AppendVarint(b, uint64(4+2+16+2+1+len(p.preferredAddrConnID)+16))
b = append(b, p.preferredAddrV4.Addr().AsSlice()...) // 4 bytes
b = binary.BigEndian.AppendUint16(b, p.preferredAddrV4.Port()) // 2 bytes
b = append(b, p.preferredAddrV6.Addr().AsSlice()...) // 16 bytes
b = binary.BigEndian.AppendUint16(b, p.preferredAddrV6.Port()) // 2 bytes
- b = appendUint8Bytes(b, p.preferredAddrConnID) // 1 byte + len(conn_id)
+ b = quicwire.AppendUint8Bytes(b, p.preferredAddrConnID) // 1 byte + len(conn_id)
b = append(b, p.preferredAddrResetToken...) // 16 bytes
}
if v := p.activeConnIDLimit; v != defaultParamActiveConnIDLimit {
- b = appendVarint(b, paramActiveConnectionIDLimit)
- b = appendVarint(b, uint64(sizeVarint(uint64(v))))
- b = appendVarint(b, uint64(v))
+ b = quicwire.AppendVarint(b, paramActiveConnectionIDLimit)
+ b = quicwire.AppendVarint(b, uint64(quicwire.SizeVarint(uint64(v))))
+ b = quicwire.AppendVarint(b, uint64(v))
}
if v := p.initialSrcConnID; v != nil {
- b = appendVarint(b, paramInitialSourceConnectionID)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramInitialSourceConnectionID)
+ b = quicwire.AppendVarintBytes(b, v)
}
if v := p.retrySrcConnID; v != nil {
- b = appendVarint(b, paramRetrySourceConnectionID)
- b = appendVarintBytes(b, v)
+ b = quicwire.AppendVarint(b, paramRetrySourceConnectionID)
+ b = quicwire.AppendVarintBytes(b, v)
}
return b
}
@@ -167,12 +169,12 @@ func marshalTransportParameters(p transportParameters) []byte {
func unmarshalTransportParams(params []byte) (transportParameters, error) {
p := defaultTransportParameters()
for len(params) > 0 {
- id, n := consumeVarint(params)
+ id, n := quicwire.ConsumeVarint(params)
if n < 0 {
return p, localTransportError{code: errTransportParameter}
}
params = params[n:]
- val, n := consumeVarintBytes(params)
+ val, n := quicwire.ConsumeVarintBytes(params)
if n < 0 {
return p, localTransportError{code: errTransportParameter}
}
@@ -184,7 +186,7 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
n = len(val)
case paramMaxIdleTimeout:
var v uint64
- v, n = consumeVarint(val)
+ v, n = quicwire.ConsumeVarint(val)
// If this is unreasonably large, consider it as no timeout to avoid
// time.Duration overflows.
if v > 1<<32 {
@@ -198,38 +200,38 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
p.statelessResetToken = val
n = 16
case paramMaxUDPPayloadSize:
- p.maxUDPPayloadSize, n = consumeVarintInt64(val)
+ p.maxUDPPayloadSize, n = quicwire.ConsumeVarintInt64(val)
if p.maxUDPPayloadSize < 1200 {
return p, localTransportError{code: errTransportParameter}
}
case paramInitialMaxData:
- p.initialMaxData, n = consumeVarintInt64(val)
+ p.initialMaxData, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamDataBidiLocal:
- p.initialMaxStreamDataBidiLocal, n = consumeVarintInt64(val)
+ p.initialMaxStreamDataBidiLocal, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamDataBidiRemote:
- p.initialMaxStreamDataBidiRemote, n = consumeVarintInt64(val)
+ p.initialMaxStreamDataBidiRemote, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamDataUni:
- p.initialMaxStreamDataUni, n = consumeVarintInt64(val)
+ p.initialMaxStreamDataUni, n = quicwire.ConsumeVarintInt64(val)
case paramInitialMaxStreamsBidi:
- p.initialMaxStreamsBidi, n = consumeVarintInt64(val)
+ p.initialMaxStreamsBidi, n = quicwire.ConsumeVarintInt64(val)
if p.initialMaxStreamsBidi > maxStreamsLimit {
return p, localTransportError{code: errTransportParameter}
}
case paramInitialMaxStreamsUni:
- p.initialMaxStreamsUni, n = consumeVarintInt64(val)
+ p.initialMaxStreamsUni, n = quicwire.ConsumeVarintInt64(val)
if p.initialMaxStreamsUni > maxStreamsLimit {
return p, localTransportError{code: errTransportParameter}
}
case paramAckDelayExponent:
var v uint64
- v, n = consumeVarint(val)
+ v, n = quicwire.ConsumeVarint(val)
if v > 20 {
return p, localTransportError{code: errTransportParameter}
}
p.ackDelayExponent = int8(v)
case paramMaxAckDelay:
var v uint64
- v, n = consumeVarint(val)
+ v, n = quicwire.ConsumeVarint(val)
if v >= 1<<14 {
return p, localTransportError{code: errTransportParameter}
}
@@ -251,7 +253,7 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
)
val = val[16+2:]
var nn int
- p.preferredAddrConnID, nn = consumeUint8Bytes(val)
+ p.preferredAddrConnID, nn = quicwire.ConsumeUint8Bytes(val)
if nn < 0 {
return p, localTransportError{code: errTransportParameter}
}
@@ -262,7 +264,7 @@ func unmarshalTransportParams(params []byte) (transportParameters, error) {
p.preferredAddrResetToken = val
val = nil
case paramActiveConnectionIDLimit:
- p.activeConnIDLimit, n = consumeVarintInt64(val)
+ p.activeConnIDLimit, n = quicwire.ConsumeVarintInt64(val)
if p.activeConnIDLimit < 2 {
return p, localTransportError{code: errTransportParameter}
}
diff --git a/quic/transport_params_test.go b/quic/transport_params_test.go
index cc88e83fd6..f1961178e8 100644
--- a/quic/transport_params_test.go
+++ b/quic/transport_params_test.go
@@ -13,6 +13,8 @@ import (
"reflect"
"testing"
"time"
+
+ "golang.org/x/net/internal/quic/quicwire"
)
func TestTransportParametersMarshalUnmarshal(t *testing.T) {
@@ -334,9 +336,9 @@ func TestTransportParameterMaxIdleTimeoutOverflowsDuration(t *testing.T) {
tooManyMS := 1 + (math.MaxInt64 / uint64(time.Millisecond))
var enc []byte
- enc = appendVarint(enc, paramMaxIdleTimeout)
- enc = appendVarint(enc, uint64(sizeVarint(tooManyMS)))
- enc = appendVarint(enc, uint64(tooManyMS))
+ enc = quicwire.AppendVarint(enc, paramMaxIdleTimeout)
+ enc = quicwire.AppendVarint(enc, uint64(quicwire.SizeVarint(tooManyMS)))
+ enc = quicwire.AppendVarint(enc, uint64(tooManyMS))
dec, err := unmarshalTransportParams(enc)
if err != nil {
diff --git a/quic/udp_packetconn.go b/quic/udp_packetconn.go
new file mode 100644
index 0000000000..85ce349ff1
--- /dev/null
+++ b/quic/udp_packetconn.go
@@ -0,0 +1,69 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build go1.21
+
+package quic
+
+import (
+ "net"
+ "net/netip"
+)
+
+// netPacketConn is a packetConn implementation wrapping a net.PacketConn.
+//
+// This is mostly useful for tests, since PacketConn doesn't provide access to
+// important features such as identifying the local address packets were received on.
+type netPacketConn struct {
+ c net.PacketConn
+ localAddr netip.AddrPort
+}
+
+func newNetPacketConn(pc net.PacketConn) (*netPacketConn, error) {
+ addr, err := addrPortFromAddr(pc.LocalAddr())
+ if err != nil {
+ return nil, err
+ }
+ return &netPacketConn{
+ c: pc,
+ localAddr: addr,
+ }, nil
+}
+
+func (c *netPacketConn) Close() error {
+ return c.c.Close()
+}
+
+func (c *netPacketConn) LocalAddr() netip.AddrPort {
+ return c.localAddr
+}
+
+func (c *netPacketConn) Read(f func(*datagram)) {
+ for {
+ dgram := newDatagram()
+ n, peerAddr, err := c.c.ReadFrom(dgram.b)
+ if err != nil {
+ return
+ }
+ dgram.peerAddr, err = addrPortFromAddr(peerAddr)
+ if err != nil {
+ continue
+ }
+ dgram.b = dgram.b[:n]
+ f(dgram)
+ }
+}
+
+func (c *netPacketConn) Write(dgram datagram) error {
+ _, err := c.c.WriteTo(dgram.b, net.UDPAddrFromAddrPort(dgram.peerAddr))
+ return err
+}
+
+func addrPortFromAddr(addr net.Addr) (netip.AddrPort, error) {
+ switch a := addr.(type) {
+ case *net.UDPAddr:
+ return a.AddrPort(), nil
+ }
+ return netip.ParseAddrPort(addr.String())
+}
diff --git a/route/address.go b/route/address.go
index 5443d67223..492838a7fe 100644
--- a/route/address.go
+++ b/route/address.go
@@ -170,29 +170,56 @@ func (a *Inet6Addr) marshal(b []byte) (int, error) {
// parseInetAddr parses b as an internet address for IPv4 or IPv6.
func parseInetAddr(af int, b []byte) (Addr, error) {
+ const (
+ off4 = 4 // offset of in_addr
+ off6 = 8 // offset of in6_addr
+ ipv4Len = 4 // length of IPv4 address in bytes
+ ipv6Len = 16 // length of IPv6 address in bytes
+ )
switch af {
case syscall.AF_INET:
- if len(b) < sizeofSockaddrInet {
+ if len(b) < int(b[0]) {
return nil, errInvalidAddr
}
+ sockAddrLen := int(b[0])
a := &Inet4Addr{}
- copy(a.IP[:], b[4:8])
+ // sockAddrLen of 0 is valid and represents 0.0.0.0
+ if sockAddrLen > off4 {
+ // Calculate how many bytes of the address to copy:
+ // either full IPv4 length or the available length.
+ n := off4 + ipv4Len
+ if sockAddrLen < n {
+ n = sockAddrLen
+ }
+ copy(a.IP[:], b[off4:n])
+ }
return a, nil
case syscall.AF_INET6:
- if len(b) < sizeofSockaddrInet6 {
+ if len(b) < int(b[0]) {
return nil, errInvalidAddr
}
- a := &Inet6Addr{ZoneID: int(nativeEndian.Uint32(b[24:28]))}
- copy(a.IP[:], b[8:24])
- if a.IP[0] == 0xfe && a.IP[1]&0xc0 == 0x80 || a.IP[0] == 0xff && (a.IP[1]&0x0f == 0x01 || a.IP[1]&0x0f == 0x02) {
- // KAME based IPv6 protocol stack usually
- // embeds the interface index in the
- // interface-local or link-local address as
- // the kernel-internal form.
- id := int(bigEndian.Uint16(a.IP[2:4]))
- if id != 0 {
- a.ZoneID = id
- a.IP[2], a.IP[3] = 0, 0
+ sockAddrLen := int(b[0])
+ a := &Inet6Addr{}
+ // sockAddrLen of 0 is valid and represents ::
+ if sockAddrLen > off6 {
+ n := off6 + ipv6Len
+ if sockAddrLen < n {
+ n = sockAddrLen
+ }
+ if sockAddrLen == sizeofSockaddrInet6 {
+ a.ZoneID = int(nativeEndian.Uint32(b[24:28]))
+ }
+ copy(a.IP[:], b[off6:n])
+ if a.IP[0] == 0xfe && a.IP[1]&0xc0 == 0x80 || a.IP[0] == 0xff && (a.IP[1]&0x0f == 0x01 || a.IP[1]&0x0f == 0x02) {
+ // KAME based IPv6 protocol stack usually
+ // embeds the interface index in the
+ // interface-local or link-local address as
+ // the kernel-internal form.
+ id := int(bigEndian.Uint16(a.IP[2:4]))
+ if id != 0 {
+ a.ZoneID = id
+ a.IP[2], a.IP[3] = 0, 0
+ }
}
}
return a, nil
@@ -369,13 +396,19 @@ func marshalAddrs(b []byte, as []Addr) (uint, error) {
func parseAddrs(attrs uint, fn func(int, []byte) (int, Addr, error), b []byte) ([]Addr, error) {
var as [syscall.RTAX_MAX]Addr
af := int(syscall.AF_UNSPEC)
+ isInet := func(fam int) bool {
+ return fam == syscall.AF_INET || fam == syscall.AF_INET6
+ }
+ isMask := func(addrType uint) bool {
+ return addrType == syscall.RTAX_NETMASK || addrType == syscall.RTAX_GENMASK
+ }
for i := uint(0); i < syscall.RTAX_MAX && len(b) >= roundup(0); i++ {
if attrs&(1<
+ // locks: inits:
+ // sockaddrs:
+ // :: fe80::2d0:4cff:fe10:15d2 ::
+ {
+ syscall.RTA_DST | syscall.RTA_GATEWAY | syscall.RTA_NETMASK,
+ parseKernelInetAddr,
+ []byte{
+ 0x1c, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+
+ 0x1c, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x02, 0xd0, 0x4c, 0xff, 0xfe, 0x10, 0x15, 0xd2,
+ 0x00, 0x00, 0x00, 0x00,
+
+ 0x02, 0x1e, 0x00, 0x00,
+ },
+ []Addr{
+ &Inet6Addr{},
+ &Inet6Addr{IP: [16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xd0, 0x4c, 0xff, 0xfe, 0x10, 0x15, 0xd2}},
+ &Inet6Addr{},
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ },
+ },
+ // golang/go#70528, the kernel can produce addresses of length 0
+ {
+ syscall.RTA_DST | syscall.RTA_GATEWAY | syscall.RTA_NETMASK,
+ parseKernelInetAddr,
+ []byte{
+ 0x00, 0x1e, 0x00, 0x00,
+
+ 0x1c, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xfe, 0x80, 0x00, 0x21, 0x00, 0x00, 0x00, 0x00,
+ 0xf2, 0x2f, 0x4b, 0xff, 0xfe, 0x09, 0x3b, 0xff,
+ 0x00, 0x00, 0x00, 0x00,
+
+ 0x0e, 0x1e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00,
+ },
+ []Addr{
+ &Inet6Addr{IP: [16]byte{}},
+ &Inet6Addr{IP: [16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf2, 0x2f, 0x4b, 0xff, 0xfe, 0x09, 0x3b, 0xff}, ZoneID: 33},
+ &Inet6Addr{IP: [16]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ },
+ },
+ // Additional case: golang/go/issues/70528#issuecomment-2498692877
+ {
+ syscall.RTA_DST | syscall.RTA_GATEWAY | syscall.RTA_NETMASK,
+ parseKernelInetAddr,
+ []byte{
+ 0x84, 0x00, 0x05, 0x04, 0x01, 0x00, 0x00, 0x00, 0x03, 0x08, 0x00, 0x01, 0x15, 0x00, 0x00, 0x00,
+ 0x1B, 0x01, 0x00, 0x00, 0xF5, 0x5A, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x02, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00,
+ 0x14, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00,
+ },
+ []Addr{
+ &Inet4Addr{IP: [4]byte{0x0, 0x0, 0x0, 0x0}},
+ nil,
+ nil,
nil,
nil,
nil,
diff --git a/route/defs_darwin.go b/route/defs_darwin.go
index ec56ca02e1..46a4ed6694 100644
--- a/route/defs_darwin.go
+++ b/route/defs_darwin.go
@@ -24,14 +24,10 @@ const (
sizeofIfmaMsghdrDarwin15 = C.sizeof_struct_ifma_msghdr
sizeofIfMsghdr2Darwin15 = C.sizeof_struct_if_msghdr2
sizeofIfmaMsghdr2Darwin15 = C.sizeof_struct_ifma_msghdr2
- sizeofIfDataDarwin15 = C.sizeof_struct_if_data
- sizeofIfData64Darwin15 = C.sizeof_struct_if_data64
sizeofRtMsghdrDarwin15 = C.sizeof_struct_rt_msghdr
sizeofRtMsghdr2Darwin15 = C.sizeof_struct_rt_msghdr2
- sizeofRtMetricsDarwin15 = C.sizeof_struct_rt_metrics
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_dragonfly.go b/route/defs_dragonfly.go
index 9bf202dda4..52aa700a6d 100644
--- a/route/defs_dragonfly.go
+++ b/route/defs_dragonfly.go
@@ -47,10 +47,8 @@ const (
sizeofIfaMsghdrDragonFlyBSD58 = C.sizeof_struct_ifa_msghdr_dfly58
- sizeofRtMsghdrDragonFlyBSD4 = C.sizeof_struct_rt_msghdr
- sizeofRtMetricsDragonFlyBSD4 = C.sizeof_struct_rt_metrics
+ sizeofRtMsghdrDragonFlyBSD4 = C.sizeof_struct_rt_msghdr
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_freebsd.go b/route/defs_freebsd.go
index abb2dc0957..68778f2d16 100644
--- a/route/defs_freebsd.go
+++ b/route/defs_freebsd.go
@@ -220,7 +220,6 @@ import "C"
const (
sizeofIfMsghdrlFreeBSD10 = C.sizeof_struct_if_msghdrl
sizeofIfaMsghdrFreeBSD10 = C.sizeof_struct_ifa_msghdr
- sizeofIfaMsghdrlFreeBSD10 = C.sizeof_struct_ifa_msghdrl
sizeofIfmaMsghdrFreeBSD10 = C.sizeof_struct_ifma_msghdr
sizeofIfAnnouncemsghdrFreeBSD10 = C.sizeof_struct_if_announcemsghdr
@@ -233,15 +232,7 @@ const (
sizeofIfMsghdrFreeBSD10 = C.sizeof_struct_if_msghdr_freebsd10
sizeofIfMsghdrFreeBSD11 = C.sizeof_struct_if_msghdr_freebsd11
- sizeofIfDataFreeBSD7 = C.sizeof_struct_if_data_freebsd7
- sizeofIfDataFreeBSD8 = C.sizeof_struct_if_data_freebsd8
- sizeofIfDataFreeBSD9 = C.sizeof_struct_if_data_freebsd9
- sizeofIfDataFreeBSD10 = C.sizeof_struct_if_data_freebsd10
- sizeofIfDataFreeBSD11 = C.sizeof_struct_if_data_freebsd11
-
- sizeofIfMsghdrlFreeBSD10Emu = C.sizeof_struct_if_msghdrl
sizeofIfaMsghdrFreeBSD10Emu = C.sizeof_struct_ifa_msghdr
- sizeofIfaMsghdrlFreeBSD10Emu = C.sizeof_struct_ifa_msghdrl
sizeofIfmaMsghdrFreeBSD10Emu = C.sizeof_struct_ifma_msghdr
sizeofIfAnnouncemsghdrFreeBSD10Emu = C.sizeof_struct_if_announcemsghdr
@@ -254,13 +245,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = C.sizeof_struct_if_msghdr_freebsd10
sizeofIfMsghdrFreeBSD11Emu = C.sizeof_struct_if_msghdr_freebsd11
- sizeofIfDataFreeBSD7Emu = C.sizeof_struct_if_data_freebsd7
- sizeofIfDataFreeBSD8Emu = C.sizeof_struct_if_data_freebsd8
- sizeofIfDataFreeBSD9Emu = C.sizeof_struct_if_data_freebsd9
- sizeofIfDataFreeBSD10Emu = C.sizeof_struct_if_data_freebsd10
- sizeofIfDataFreeBSD11Emu = C.sizeof_struct_if_data_freebsd11
-
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_netbsd.go b/route/defs_netbsd.go
index 8e89934c5a..fb60f43c83 100644
--- a/route/defs_netbsd.go
+++ b/route/defs_netbsd.go
@@ -23,10 +23,8 @@ const (
sizeofIfaMsghdrNetBSD7 = C.sizeof_struct_ifa_msghdr
sizeofIfAnnouncemsghdrNetBSD7 = C.sizeof_struct_if_announcemsghdr
- sizeofRtMsghdrNetBSD7 = C.sizeof_struct_rt_msghdr
- sizeofRtMetricsNetBSD7 = C.sizeof_struct_rt_metrics
+ sizeofRtMsghdrNetBSD7 = C.sizeof_struct_rt_msghdr
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/defs_openbsd.go b/route/defs_openbsd.go
index 8f3218bc63..471558d9ef 100644
--- a/route/defs_openbsd.go
+++ b/route/defs_openbsd.go
@@ -21,7 +21,6 @@ import "C"
const (
sizeofRtMsghdr = C.sizeof_struct_rt_msghdr
- sizeofSockaddrStorage = C.sizeof_struct_sockaddr_storage
- sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
- sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
+ sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
+ sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)
diff --git a/route/example_darwin_test.go b/route/example_darwin_test.go
new file mode 100644
index 0000000000..e442c3ecf7
--- /dev/null
+++ b/route/example_darwin_test.go
@@ -0,0 +1,70 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package route_test
+
+import (
+ "fmt"
+ "net/netip"
+ "os"
+ "syscall"
+
+ "golang.org/x/net/route"
+ "golang.org/x/sys/unix"
+)
+
+// This example demonstrates how to parse a response to an RTM_GET request.
+func ExampleParseRIB() {
+ fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ if err != nil {
+ return
+ }
+ defer unix.Close(fd)
+
+ // Create a RouteMessage with RTM_GET type
+ rtm := &route.RouteMessage{
+ Version: syscall.RTM_VERSION,
+ Type: unix.RTM_GET,
+ ID: uintptr(os.Getpid()),
+ Seq: 0,
+ Addrs: []route.Addr{
+ &route.Inet4Addr{IP: [4]byte{127, 0, 0, 0}},
+ },
+ }
+
+ // Marshal the message into bytes
+ msgBytes, err := rtm.Marshal()
+ if err != nil {
+ return
+ }
+
+ // Send the message over the routing socket
+ _, err = unix.Write(fd, msgBytes)
+ if err != nil {
+ return
+ }
+
+ // Read the response from the routing socket
+ var buf [2 << 10]byte
+ n, err := unix.Read(fd, buf[:])
+ if err != nil {
+ return
+ }
+
+ // Parse the response messages
+ msgs, err := route.ParseRIB(route.RIBTypeRoute, buf[:n])
+ if err != nil {
+ return
+ }
+ routeMsg, ok := msgs[0].(*route.RouteMessage)
+ if !ok {
+ return
+ }
+ netmask, ok := routeMsg.Addrs[2].(*route.Inet4Addr)
+ if !ok {
+ return
+ }
+ fmt.Println(netip.AddrFrom4(netmask.IP))
+ // Output: 255.0.0.0
+}
diff --git a/route/sys_netbsd.go b/route/sys_netbsd.go
index be4460e13f..c6bb6bc8a2 100644
--- a/route/sys_netbsd.go
+++ b/route/sys_netbsd.go
@@ -25,7 +25,7 @@ func (m *RouteMessage) Sys() []Sys {
}
}
-// RouteMetrics represents route metrics.
+// InterfaceMetrics represents interface metrics.
type InterfaceMetrics struct {
Type int // interface type
MTU int // maximum transmission unit
diff --git a/route/zsys_darwin.go b/route/zsys_darwin.go
index 56a0c66f44..adaa460026 100644
--- a/route/zsys_darwin.go
+++ b/route/zsys_darwin.go
@@ -9,14 +9,10 @@ const (
sizeofIfmaMsghdrDarwin15 = 0x10
sizeofIfMsghdr2Darwin15 = 0xa0
sizeofIfmaMsghdr2Darwin15 = 0x14
- sizeofIfDataDarwin15 = 0x60
- sizeofIfData64Darwin15 = 0x80
sizeofRtMsghdrDarwin15 = 0x5c
sizeofRtMsghdr2Darwin15 = 0x5c
- sizeofRtMetricsDarwin15 = 0x38
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_dragonfly.go b/route/zsys_dragonfly.go
index f7c7a60cd6..209cb20af8 100644
--- a/route/zsys_dragonfly.go
+++ b/route/zsys_dragonfly.go
@@ -11,10 +11,8 @@ const (
sizeofIfaMsghdrDragonFlyBSD58 = 0x18
- sizeofRtMsghdrDragonFlyBSD4 = 0x98
- sizeofRtMetricsDragonFlyBSD4 = 0x70
+ sizeofRtMsghdrDragonFlyBSD4 = 0x98
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_386.go b/route/zsys_freebsd_386.go
index 3f985c7ee9..ec617772b2 100644
--- a/route/zsys_freebsd_386.go
+++ b/route/zsys_freebsd_386.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0x68
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0x6c
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,18 +18,10 @@ const (
sizeofIfMsghdrFreeBSD10 = 0x64
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x50
- sizeofIfDataFreeBSD8 = 0x50
- sizeofIfDataFreeBSD9 = 0x50
- sizeofIfDataFreeBSD10 = 0x54
- sizeofIfDataFreeBSD11 = 0x98
-
// MODIFIED BY HAND FOR 386 EMULATION ON AMD64
// 386 EMULATION USES THE UNDERLYING RAW DATA LAYOUT
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -43,13 +34,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_amd64.go b/route/zsys_freebsd_amd64.go
index 9293393698..3d7f31d13e 100644
--- a/route/zsys_freebsd_amd64.go
+++ b/route/zsys_freebsd_amd64.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0xb0
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0xb0
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0xa8
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x98
- sizeofIfDataFreeBSD8 = 0x98
- sizeofIfDataFreeBSD9 = 0x98
- sizeofIfDataFreeBSD10 = 0x98
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_arm.go b/route/zsys_freebsd_arm.go
index a2bdb4ad3b..931afa3931 100644
--- a/route/zsys_freebsd_arm.go
+++ b/route/zsys_freebsd_arm.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0x68
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0x6c
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0x70
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x60
- sizeofIfDataFreeBSD8 = 0x60
- sizeofIfDataFreeBSD9 = 0x60
- sizeofIfDataFreeBSD10 = 0x60
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0x68
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0x6c
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0x70
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x60
- sizeofIfDataFreeBSD8Emu = 0x60
- sizeofIfDataFreeBSD9Emu = 0x60
- sizeofIfDataFreeBSD10Emu = 0x60
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_arm64.go b/route/zsys_freebsd_arm64.go
index 9293393698..3d7f31d13e 100644
--- a/route/zsys_freebsd_arm64.go
+++ b/route/zsys_freebsd_arm64.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0xb0
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0xb0
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0xa8
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x98
- sizeofIfDataFreeBSD8 = 0x98
- sizeofIfDataFreeBSD9 = 0x98
- sizeofIfDataFreeBSD10 = 0x98
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_freebsd_riscv64.go b/route/zsys_freebsd_riscv64.go
index 9293393698..3d7f31d13e 100644
--- a/route/zsys_freebsd_riscv64.go
+++ b/route/zsys_freebsd_riscv64.go
@@ -6,7 +6,6 @@ package route
const (
sizeofIfMsghdrlFreeBSD10 = 0xb0
sizeofIfaMsghdrFreeBSD10 = 0x14
- sizeofIfaMsghdrlFreeBSD10 = 0xb0
sizeofIfmaMsghdrFreeBSD10 = 0x10
sizeofIfAnnouncemsghdrFreeBSD10 = 0x18
@@ -19,15 +18,7 @@ const (
sizeofIfMsghdrFreeBSD10 = 0xa8
sizeofIfMsghdrFreeBSD11 = 0xa8
- sizeofIfDataFreeBSD7 = 0x98
- sizeofIfDataFreeBSD8 = 0x98
- sizeofIfDataFreeBSD9 = 0x98
- sizeofIfDataFreeBSD10 = 0x98
- sizeofIfDataFreeBSD11 = 0x98
-
- sizeofIfMsghdrlFreeBSD10Emu = 0xb0
sizeofIfaMsghdrFreeBSD10Emu = 0x14
- sizeofIfaMsghdrlFreeBSD10Emu = 0xb0
sizeofIfmaMsghdrFreeBSD10Emu = 0x10
sizeofIfAnnouncemsghdrFreeBSD10Emu = 0x18
@@ -40,13 +31,6 @@ const (
sizeofIfMsghdrFreeBSD10Emu = 0xa8
sizeofIfMsghdrFreeBSD11Emu = 0xa8
- sizeofIfDataFreeBSD7Emu = 0x98
- sizeofIfDataFreeBSD8Emu = 0x98
- sizeofIfDataFreeBSD9Emu = 0x98
- sizeofIfDataFreeBSD10Emu = 0x98
- sizeofIfDataFreeBSD11Emu = 0x98
-
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_netbsd.go b/route/zsys_netbsd.go
index eaffe8c408..90ce707d47 100644
--- a/route/zsys_netbsd.go
+++ b/route/zsys_netbsd.go
@@ -8,10 +8,8 @@ const (
sizeofIfaMsghdrNetBSD7 = 0x18
sizeofIfAnnouncemsghdrNetBSD7 = 0x18
- sizeofRtMsghdrNetBSD7 = 0x78
- sizeofRtMetricsNetBSD7 = 0x50
+ sizeofRtMsghdrNetBSD7 = 0x78
- sizeofSockaddrStorage = 0x80
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/route/zsys_openbsd.go b/route/zsys_openbsd.go
index b11b812680..64fbdd98fb 100644
--- a/route/zsys_openbsd.go
+++ b/route/zsys_openbsd.go
@@ -6,7 +6,6 @@ package route
const (
sizeofRtMsghdr = 0x60
- sizeofSockaddrStorage = 0x100
- sizeofSockaddrInet = 0x10
- sizeofSockaddrInet6 = 0x1c
+ sizeofSockaddrInet = 0x10
+ sizeofSockaddrInet6 = 0x1c
)
diff --git a/webdav/file_test.go b/webdav/file_test.go
index 3af53fde31..c9313dc5bb 100644
--- a/webdav/file_test.go
+++ b/webdav/file_test.go
@@ -517,12 +517,7 @@ func TestDir(t *testing.T) {
t.Skip("see golang.org/issue/11453")
}
- td, err := os.MkdirTemp("", "webdav-test")
- if err != nil {
- t.Fatal(err)
- }
- defer os.RemoveAll(td)
- testFS(t, Dir(td))
+ testFS(t, Dir(t.TempDir()))
}
func TestMemFS(t *testing.T) {
diff --git a/websocket/hybi_test.go b/websocket/hybi_test.go
index 9504aa2d30..f0715d3f6f 100644
--- a/websocket/hybi_test.go
+++ b/websocket/hybi_test.go
@@ -163,7 +163,7 @@ Sec-WebSocket-Protocol: chat
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
- t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
+ t.Errorf("%s expected %q but got %q", k, v, req.Header.Get(k))
}
}
}
diff --git a/websocket/websocket.go b/websocket/websocket.go
index 923a5780ec..ac76165ceb 100644
--- a/websocket/websocket.go
+++ b/websocket/websocket.go
@@ -8,7 +8,7 @@
// This package currently lacks some features found in an alternative
// and more actively maintained WebSocket package:
//
-// https://pkg.go.dev/nhooyr.io/websocket
+// https://pkg.go.dev/github.com/coder/websocket
package websocket // import "golang.org/x/net/websocket"
import (
diff --git a/xsrftoken/xsrf.go b/xsrftoken/xsrf.go
index 3ca5d5b9f5..e808e6dd80 100644
--- a/xsrftoken/xsrf.go
+++ b/xsrftoken/xsrf.go
@@ -45,10 +45,9 @@ func generateTokenAtTime(key, userID, actionID string, now time.Time) string {
h := hmac.New(sha1.New, []byte(key))
fmt.Fprintf(h, "%s:%s:%d", clean(userID), clean(actionID), milliTime)
- // Get the padded base64 string then removing the padding.
+ // Get the unpadded base64 string.
tok := string(h.Sum(nil))
- tok = base64.URLEncoding.EncodeToString([]byte(tok))
- tok = strings.TrimRight(tok, "=")
+ tok = base64.RawURLEncoding.EncodeToString([]byte(tok))
return fmt.Sprintf("%s:%d", tok, milliTime)
}