8000 WIP · shingjan/python-sdk@a0e2f7f · GitHub
[go: up one dir, main page]

Skip to content

Commit a0e2f7f

Browse files
committed
WIP
1 parent 3f9f7c8 commit a0e2f7f

File tree

1 file changed

+121
-97
lines changed

1 file changed

+121
-97
lines changed

tests/client/test_sse_attempt.py

Lines changed: 121 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import anyio
2+
import asyncio
23
import pytest
34
from starlette.applications import Starlette
45
from starlette.routing import Mount, Route
@@ -24,32 +25,42 @@ async def sse_app(sse_transport):
2425
async def handle_sse(request):
2526
"""Handler for SSE connections."""
2627
async def event_generator():
27-
# Send initial connection event
28-
yield {
29-
"event": "endpoint",
30-
"data": "/messages",
31-
}
32-
33-
# Keep connection alive
34-
async with sse_transport.connect_sse(
35-
request.scope, request.receive, request._send
36-
) as streams:
37-
client_to_server, server_to_client = streams
38-
async for message in client_to_server:
28+
try:
29+
async with sse_transport.connect_sse(
30+
request.scope, request.receive, request._send
31+
) as streams:
32+
client_to_server, server_to_client = streams
33+
# Send initial connection event
3934
yield {
40-
"event": "message",
41-
"data": message.model_dump_json(),
35+
"event": "endpoint",
36+
"data": "/messages",
4237
}
4338

44-
return EventSourceResponse(event_generator())
39+
# Process messages
40+
async with anyio.create_task_group() as tg:
41+
try:
42+
async for message in client_to_server:
43+
if isinstance(message, Exception):
44+
break
45+
yield {
46+
"event": "message",
47+
"data": message.model_dump_json(),
48+
}
49+
except (asyncio.CancelledError, GeneratorExit):
50+
print('cancelled')
51+
return
52+
except Exception as e:
53+
print("unhandled exception:", e)
54+
return
55+
except Exception:
56+
# Log any unexpected errors but allow connection to close gracefully
57+
pass
4558

46-
async def handle_post(request):
47-
"""Handler for POST messages."""
48-
return Response(status_code=200)
59+
return EventSourceResponse(event_generator())
4960

5061
routes = [
5162
Route("/sse", endpoint=handle_sse),
52-
Route("/messages", endpoint=handle_post, methods=["POST"]),
63+
Mount("/messages", app=sse_transport.handle_post_message),
5364
]
5465

5566
return Starlette(routes=routes)
@@ -62,88 +73,101 @@ async def test_client(sse_app):
6273
async with httpx.AsyncClient(
6374
transport=transport,
6475
base_url="http://testserver",
65-
timeout=5.0,
76+
timeout=10.0,
6677
) as client:
6778
yield client
6879

6980

7081
@pytest.mark.anyio
7182
async def test_sse_connection(test_client):
7283
"""Test basic SSE connection and message exchange."""
73-
async with sse_client(
74-
"http://testserver/sse",
75-
headers={"Host": "testserver"},
76-
timeout=2,
77-
sse_read_timeout=1,
78-
client=test_client,
79-
) as (read_stream, write_stream):
80-
# Send a test message
81-
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
82-
await write_stream.send(test_message)
83-
84-
# Receive echoed message
85-
async with read_stream:
86-
message = await read_stream.__anext__()
87-
assert isinstance(message, JSONRPCMessage)
88-
assert message.model_dump() == test_message.model_dump()
89-
90-
91-
@pytest.mark.anyio
92-
async def test_sse_read_timeout(test_client):
93-
"""Test that SSE client properly handles read timeouts."""
94-
with pytest.raises(ReadTimeout):
95-
async with sse_client(
96-
"http://testserver/sse",
97-
headers={"Host": "testserver"},
98-
timeout=2,
99-
sse_read_timeout=1,
100-
client=test_client,
101-
) as (read_stream, write_stream):
102-
async with read_stream:
103-
# This should timeout since no messages are being sent
104-
await read_stream.__anext__()
105-
106-
107-
@pytest.mark.anyio
108-
async def test_sse_connection_error(test_client):
109-
"""Test SSE client behavior with connection errors."""
110-
with pytest.raises(httpx.HTTPError):
111-
async with sse_client(
112-
"http://testserver/nonexistent",
113-
headers={"Host": "testserver"},
114-
timeout=2,
115-
client=test_client,
116-
):
117-
pass # Should not reach here
118-
119-
120-
@pytest.mark.anyio
121-
async def test_sse_multiple_messages(test_client):
122-
"""Test sending and receiving multiple SSE messages."""
123-
async with sse_client(
124-
"http://testserver/sse",
125-
headers={"Host": "testserver"},
126-
timeout=2,
127-
sse_read_timeout=1,
128-
client=test_client,
129-
) as (read_stream, write_stream):
130-
# Send multiple test messages
131-
messages = [
132-
JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
133-
for i in range(3)
134-
]
135-
136-
for msg in messages:
137-
await write_stream.send(msg)
138-
139-
# Receive all echoed messages
140-
received = []
141-
async with read_stream:
142-
for _ in range(len(messages)):
143-
message = await read_stream.__anext__()
144-
assert isinstance(message, JSONRPCMessage)
145-
received.append(message)
146-
147-
# Verify all messages were received in order
148-
for sent, received in zip(messages, received):
149-
assert sent.model_dump() == received.model_dump()
84+
async with anyio.create_task_group() as tg:
85+
try:
86+
async with sse_client(
87+
"http://testserver/sse",
88+
headers={"Host": "testserver"},
89+
timeout=5,
90+
sse_read_timeout=5,
91+
client=test_client,
92+
) as (read_stream, write_stream):
93+
# First get the initial endpoint message
94+
async with read_stream:
95+
init_message = await read_stream.__anext__()
96+
assert isinstance(init_message, JSONRPCMessage)
97+
98+
# Send a test message
99+
test_message = JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": "test"})
100+
await write_stream.send(test_message)
101+
102+
# Receive echoed message
103+
async with read_stream:
104+
message = await read_stream.__anext__()
105+
assert isinstance(message, JSONRPCMessage)
106+
assert message.model_dump() == test_message.model_dump()
107+
108+
# Explicitly close streams
109+
await write_stream.aclose()
110+
await read_stream.aclose()
111+
except Exception as e:
112+
pytest.fail(f"Test failed with error: {str(e)}")
113+
114+
115+
# @pytest.mark.anyio
116+
# async def test_sse_read_timeout(test_client):
117+
# """Test that SSE client properly handles read timeouts."""
118+
# with pytest.raises(ReadTimeout):
119+
# async with sse_client(
120+
# "http://testserver/sse",
121+
# headers={"Host": "testserver"},
122+
# timeout=5,
123+
# sse_read_timeout=2,
124+
# client=test_client,
125+
# ) as (read_stream, write_stream):
126+
# async with read_stream:
127+
# # This should timeout since no messages are being sent
128+
# await read_stream.__anext__()
129+
130+
131+
# @pytest.mark.anyio
132+
# async def test_sse_connection_error(test_client):
133+
# """Test SSE client behavior with connection errors."""
134+
# with pytest.raises(httpx.HTTPError):
135+
# async with sse_client(
136+
# "http://testserver/nonexistent",
137+
# headers={"Host": "testserver"},
138+
# timeout=5,
139+
# client=test_client,
140+
# ):
141+
# pass # Should not reach here
142+
143+
144+
# @pytest.mark.anyio
145+
# async def test_sse_multiple_messages(test_client):
146+
# """Test sending and receiving multiple SSE messages."""
147+
# async with sse_client(
148+
# "http://testserver/sse",
149+
# headers={"Host": "testserver"},
150+
# timeout=5,
151+
# sse_read_timeout=5,
152+
# client=test_client,
153+
# ) as (read_stream, write_stream):
154+
# # Send multiple test messages
155+
# messages = [
156+
# JSONRPCMessage.model_validate({"jsonrpc": "2.0", "method": f"test{i}"})
157+
# for i in range(3)
158+
# ]
159+
160+
# for msg in messages:
161+
# await write_stream.send(msg)
162+
163+
# # Receive all echoed messages
164+
# received = []
165+
# async with read_stream:
166+
# for _ in range(len(messages)):
167+
# message = await read_stream.__anext__()
168+
# assert isinstance(message, JSONRPCMessage)
169+
# received.append(message)
170+
171+
# # Verify all messages were received in order
172+
# for sent, received in zip(messages, received):
173+
# assert sent.model_dump() == received.model_dump()

0 commit comments

Comments
 (0)
0