8000 passing SSE client test · shingjan/python-sdk@e79a564 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

8000
Appearance settings

Commit e79a564

Browse files
committed
passing SSE client test
1 parent 66ccd1c commit e79a564

File tree

1 file changed

+49
-84
lines changed

1 file changed

+49
-84
lines changed

tests/shared/test_sse.py

Lines changed: 49 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,36 @@
33
import time
44
import json
55
import anyio
6+
from pydantic import AnyUrl
7+
from pydantic_core import Url
68
import pytest
79
import httpx
810
from typing import AsyncGenerator
911
from starlette.applications import Starlette
1012
from starlette.routing import Mount 8000 , Route
1113

14+
from mcp.client.session import ClientSession
15+
from mcp.client.sse import sse_client
1216
from mcp.server import Server
1317
from mcp.server.sse import SseServerTransport
14-
from mcp.types import TextContent, Tool
18+
from mcp.types import EmptyResult, InitializeResult, TextContent, TextResourceContents, Tool
19+
20+
SERVER_URL = "http://127.0.0.1:8765"
21+
SERVER_SSE_URL = f"{SERVER_URL}/sse"
22+
23+
SERVER_NAME = "test_server_for_SSE"
1524

1625
# Test server implementation
1726
class TestServer(Server):
1827
def __init__(self):
19-
super().__init__("test_server")
28+
super().__init__(SERVER_NAME)
29+
30+
@self.read_resource()
31+
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
32+
if uri.scheme == "foobar":
33+
return f"Read {uri.host}"
34+
# TODO: make this an error
35+
return "NOT FOUND"
2036

2137
@self.list_tools()
2238
async def handle_list_tools():
@@ -76,18 +92,18 @@ def server(server_app: Starlette):
7692
server_thread.join(timeout=0.1)
7793

7894
@pytest.fixture()
79-
async def client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
95+
async def http_client(server) -> AsyncGenerator[httpx.AsyncClient, None]:
8096
"""Create test client"""
81-
async with httpx.AsyncClient(base_url="http://127.0.0.1:8765") as client:
97+
async with httpx.AsyncClient(base_url=SERVER_URL) as client:
8298
yield client
8399

84100
# Tests
85101
@pytest.mark.anyio
86-
async def test_sse_connection(client: httpx.AsyncClient):
87-
"""Test SSE connection establishment"""
102+
async def test_raw_sse_connection(http_client: httpx.AsyncClient):
103+
"""Test the SSE connection establishment simply with an HTTP client."""
88104
async with anyio.create_task_group() as tg:
89105
async def connection_test():
90-
async with client.stream("GET", "/sse") as response:
106+
async with http_client.stream("GET", "/sse") as response:
91107
assert response.status_code == 200
92108
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
93109

@@ -105,84 +121,33 @@ async def connection_test():
105121
with anyio.fail_after(3):
106122
await connection_test()
107123

108-
@pytest.mark.anyio
109-
async def test_message_exchange(client: httpx.AsyncClient):
110-
"""Test full message exchange flow"""
111-
# Connect to SSE endpoint
112-
session_id = None
113-
endpoint_url = None
114-
115-
async with client.stream("GET", "/sse") as sse_response:
116-
assert sse_response.status_code == 200
117-
118-
# Get endpoint URL and session ID
119-
async for line in sse_response.aiter_lines():
120-
if line.startswith("data: "):
121-
endpoint_url = json.loads(line[6:])
122-
session_id = endpoint_url.split("session_id=")[1]
123-
break
124-
125-
assert endpoint_url and session_id
126-
127-
# Send initialize request
128-
init_request = {
129-
"jsonrpc": "2.0",
130-
"id": 1,
131-
"method": "initialize",
132-
"params": {
133-
"protocolVersion": "2024-11-05",
134-
"capabilities": {},
135-
"clientInfo": {
136-
"name": "test_client",
137-
"version": "1.0"
138-
}
139-
}
140-
}
141-
142-
response = await client.post(
143-
endpoint_url,
144-
json=init_request
145-
)
146-
assert response.status_code == 202
147-
148-
# Get initialize response from SSE stream
149-
async for line in sse_response.aiter_lines():
150-
if line.startswith("event: message"):
151-
data_line = next(sse_response.aiter_lines())
152-
response = json.loads(data_line[6:]) # Strip "data: " prefix
153-
assert response["jsonrpc"] == "2.0"
154-
assert response["id"] == 1
155-
assert "result" in response
156-
break
157124

158125
@pytest.mark.anyio
159-
async def test_invalid_session(client: httpx.AsyncClient):
160-
"""Test sending message with invalid session ID"""
161-
response = await client.post(
162-
"/messages/?session_id=invalid",
163-
json={"jsonrpc": "2.0", "method": "ping"}
164-
)
165-
assert response.status_code == 400
126+
async def test_sse_client_basic_connection(server):
127+
async with sse_client(SERVER_SSE_URL) as streams:
128+
async with ClientSession(*streams) as session:
129+
# Test initialization
130+
result = await session.initialize()
131+
assert isinstance(result, InitializeResult)
132+
assert result.serverInfo.name == SERVER_NAME
133+
134+
# Test ping
135+
ping_result = await session.send_ping()
136+
assert isinstance(ping_result, EmptyResult)
137+
138+
@pytest.fixture
139+
async def initialized_sse_client_session(server) -> AsyncGenerator[ClientSession, None]:
140+
async with sse_client(SERVER_SSE_URL) as streams:
141+
async with ClientSession(*streams) as session:
142+
await session.initialize()
143+
yield session
166144

167145
@pytest.mark.anyio
168-
async def test_connection_cleanup(server_app):
169-
"""Test that resources are cleaned up when client disconnects"""
170-
sse = next(
171-
route.app for route in server_app.routes
172-
if isinstance(route, Mount) and route.path == "/messages/"
173-
).transport
174-
175-
async with httpx.AsyncClient(app=server_app, base_url="http://test") as client:
176-
# Connect and get session ID
177-
async with client.stream("GET", "/sse") as response:
178-
for line in response.iter_lines():
179-
if line.startswith("data: "):
180-
endpoint_url = json.loads(line[6:])
181-
session_id = endpoint_url.split("session_id=")[1]
182-
break
183-
184-
assert len(sse._read_stream_writers) == 1
185-
186-
# After connection closes, writer should be cleaned up
187-
await anyio.sleep(0.1) # Give cleanup a moment
188-
assert len(sse._read_stream_writers) == 0
146+
async def test_sse_client_request_and_response(initialized_sse_client_session: ClientSession):
147+
session = initialized_sse_client_session
148+
# TODO: expect raise
149+
await session.read_resource(uri=AnyUrl("xxx://will-not-work"))
150+
response = await session.read_resource(uri=AnyUrl("foobar://should-work"))
151+
assert len(response.contents) == 1
152+
assert isinstance(response.contents[0], TextResourceContents)
153+
assert response.contents[0].text == "Read should-work"

0 commit comments

Comments
 (0)
0