4
4
from starlette .routing import Mount , Route
5
5
import httpx
6
6
from httpx import ReadTimeout , ASGITransport
7
+ from starlette .responses import Response
8
+ from sse_starlette .sse import EventSourceResponse
7
9
8
10
from mcp .client .sse import sse_client
9
11
from mcp .server .sse import SseServerTransport
@@ -21,17 +23,33 @@ async def sse_app(sse_transport):
21
23
"""Fixture that creates a Starlette app with SSE endpoints."""
22
24
async def handle_sse (request ):
23
25
"""Handler for SSE connections."""
24
- async with sse_transport .connect_sse (
25
- request .scope , request .receive , request ._send
26
- ) as streams :
27
- client_to_server , server_to_client = streams
28
- async for message in client_to_server :
29
- # Echo messages back for testing
30
- await server_to_client .send (message )
26
+ async def event_generator ():
27
+ # Send initial connection event
28
+ yield {
29
+ "event" : "endpoint" ,
30
+ "data" : "/messages" ,
31
+ }
32
+
33
+ # Keep connection alive
34
+ async with sse_transport .connect_sse (
35
+ request .scope , request .receive , request ._send
36
+ ) as streams :
37
+ client_to_server , server_to_client = streams
38
+ async for message in client_to_server :
39
+ yield {
40
+ "event" : "message" ,
41
+ "data" : message .model_dump_json (),
42
+ }
43
+
44
+ return EventSourceResponse (event_generator ())
45
+
46
+ async def handle_post (request ):
47
+ """Handler for POST messages."""
48
+ return Response (status_code = 200 )
31
49
32
50
routes = [
33
51
Route ("/sse" , endpoint = handle_sse ),
34
- Mount ("/messages" , app = sse_transport . handle_post_message ),
52
+ Route ("/messages" , endpoint = handle_post , methods = [ "POST" ] ),
35
53
]
36
54
37
55
return Starlette (routes = routes )
@@ -40,9 +58,11 @@ async def handle_sse(request):
40
58
@pytest .fixture
41
59
async def test_client (sse_app ):
42
60
"""Create a test client with ASGI transport."""
61
+ transport = ASGITransport (app = sse_app )
43
62
async with httpx .AsyncClient (
44
- transport = ASGITransport ( app = sse_app ) ,
63
+ transport = transport ,
45
64
base_url = "http://testserver" ,
65
+ timeout = 5.0 ,
46
66
) as client :
47
67
yield client
48
68
@@ -53,7 +73,8 @@ async def test_sse_connection(test_client):
53
73
async with sse_client (
54
74
"http://testserver/sse" ,
55
75
headers = {"Host" : "testserver" },
56
- timeout = 5 ,
76
+ timeout = 2 ,
77
+ sse_read_timeout = 1 ,
57
78
client = test_client ,
58
79
) as (read_stream , write_stream ):
59
80
# Send a test message
@@ -74,7 +95,7 @@ async def test_sse_read_timeout(test_client):
74
95
async with sse_client (
75
96
"http://testserver/sse" ,
76
97
headers = {"Host" : "testserver" },
77
- timeout = 5 ,
98
+ timeout = 2 ,
78
99
sse_read_timeout = 1 ,
79
100
client = test_client ,
80
101
) as (read_stream , write_stream ):
@@ -90,7 +111,7 @@ async def test_sse_connection_error(test_client):
90
111
async with sse_client (
91
112
"http://testserver/nonexistent" ,
92
113
headers = {"Host" : "testserver" },
93
- timeout = 5 ,
114
+ timeout = 2 ,
94
115
client = test_client ,
95
116
):
96
117
pass # Should not reach here
@@ -102,7 +123,8 @@ async def test_sse_multiple_messages(test_client):
102
123
async with sse_client (
103
124
"http://testserver/sse" ,
104
125
headers = {"Host" : "testserver" },
105
- timeout = 5 ,
126
+ timeout = 2 ,
127
+ sse_read_timeout = 1 ,
106
128
client = test_client ,
107
129
) as (read_stream , write_stream ):
108
130
# Send multiple test messages
0 commit comments