@@ -4,11 +4,11 @@ import (
4
4
"context"
5
5
"net/http"
6
6
"net/http/httptest"
7
+ "sync"
7
8
"testing"
8
9
"time"
9
10
10
11
"cdr.dev/slog"
11
- "github.com/coder/coder/v2/coderd/httpapi"
12
12
"github.com/coder/coder/v2/coderd/tracing"
13
13
"github.com/coder/coder/v2/testutil"
14
14
"github.com/coder/websocket"
@@ -96,15 +96,22 @@ func TestLoggerMiddleware_WebSocket(t *testing.T) {
96
96
sink := & fakeSink {}
97
97
logger := slog .Make (sink )
98
98
logger = logger .Leveled (slog .LevelDebug )
99
-
99
+ var wg sync. WaitGroup
100
100
// Create a test handler to simulate a WebSocket connection
101
101
testHandler := http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
102
- _ , err := websocket .Accept (rw , r , nil )
102
+
103
+ conn , err := websocket .Accept (rw , r , nil )
103
104
if err != nil {
104
- httpapi . Write ( ctx , rw , http . StatusBadRequest , nil )
105
+ t . Errorf ( "failed to accept websocket: %v" , err )
105
106
return
106
107
}
107
- time .Sleep (1000 )
108
+ defer conn .Close (websocket .StatusNormalClosure , "" )
109
+ defer wg .Done ()
110
+
111
+ // Send a couple of messages for testing
112
+ _ = conn .Write (ctx , websocket .MessageText , []byte ("ping" ))
113
+ _ = conn .Write (ctx , websocket .MessageText , []byte ("pong" ))
114
+
108
115
})
109
116
110
117
// Wrap the test handler with the Logger middleware
@@ -120,9 +127,10 @@ func TestLoggerMiddleware_WebSocket(t *testing.T) {
120
127
// Create a test HTTP request
121
128
srv := httptest .NewServer (customHandler )
122
129
defer srv .Close ()
123
-
130
+ wg . Add ( 1 )
124
131
// nolint: bodyclose
125
132
conn , _ , err := websocket .Dial (ctx , srv .URL , nil )
133
+ wg .Wait ()
126
134
if err != nil {
127
135
t .Fatalf ("failed to create WebSocket connection: %v" , err )
128
136
}
0 commit comments