8000 fix: Pass cursor parameter to server (#745) · zendobk/mcp-python@e80c015 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit e80c015

Browse files
authored
fix: Pass cursor parameter to server (modelcontextprotocol#745)
1 parent 2ca2de7 commit e80c015

File tree

5 files changed

+306
-73
lines changed

5 files changed

+306
-73
lines changed

src/mcp/client/session.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ async def list_resources(
209209
types.ClientRequest(
210210
types.ListResourcesRequest(
211211
method="resources/list",
212-
cursor=cursor,
212+
params=types.PaginatedRequestParams(cursor=cursor)
213+
if cursor is not None
214+
else None,
213215
)
214216
),
215217
types.ListResourcesResult,
@@ -223,7 +225,9 @@ async def list_resource_templates(
223225
types.ClientRequest(
224226
types.ListResourceTemplatesRequest(
225227
method="resources/templates/list",
226-
cursor=cursor,
228+
params=types.PaginatedRequestParams(cursor=cursor)
229+
if cursor is not None
230+
else None,
227231
)
228232
),
229233
types.ListResourceTemplatesResult,
@@ -295,7 +299,9 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu
295299
types.ClientRequest(
296300
types.ListPromptsRequest(
297301
method="prompts/list",
298-
cursor=cursor,
302+
params=types.PaginatedRequestParams(cursor=cursor)
303+
if cursor is not None
304+
else None,
299305
)
300306
),
301307
types.ListPromptsResult,
@@ -340,7 +346,9 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult:
340346
types.ClientRequest(
341347
types.ListToolsRequest(
342348
method="tools/list",
343-
cursor=cursor,
349+
params=types.PaginatedRequestParams(cursor=cursor)
350+
if cursor is not None
351+
else None,
344352
)
345353
),
346354
types.ListToolsResult,

src/mcp/types.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ class Meta(BaseModel):
5353
meta: Meta | None = Field(alias="_meta", default=None)
5454

5555

56+
class PaginatedRequestParams(RequestParams):
57+
cursor: Cursor | None = None
58+
"""
59+
An opaque token representing the current pagination position.
60+
If provided, the server should return results starting after this cursor.
61+
"""
62+
63+
5664
class NotificationParams(BaseModel):
5765
class Meta(BaseModel):
5866
model_config = ConfigDict(extra="allow")
@@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
7987
model_config = ConfigDict(extra="allow")
8088

8189

82-
class PaginatedRequest(Request[RequestParamsT, MethodT]):
83-
cursor: Cursor | None = None
84-
"""
85-
An opaque token representing the current pagination position.
86-
If provided, the server should return results starting after this cursor.
87-
"""
90+
class PaginatedRequest(
91+
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
92+
):
93+
"""Base class for paginated requests,
94+
matching the schema's PaginatedRequest interface."""
95+
96+
params: PaginatedRequestParams | None = None
8897

8998

9099
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@@ -358,13 +367,10 @@ class ProgressNotification(
358367
params: ProgressNotificationParams
359368

360369

361-
class ListResourcesRequest(
362-
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
363-
):
370+
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
364371
"""Sent from the client to request a list of resources the server has."""
365372

366373
method: Literal["resources/list"]
367-
params: RequestParams | None = None
368374

369375

370376
class Annotations(BaseModel):
@@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult):
423429

424430

425431
class ListResourceTemplatesRequest(
426-
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
432+
PaginatedRequest[Literal["resources/templates/list"]]
427433
):
428434
"""Sent from the client to request a list of resource templates the server has."""
429435

430436
method: Literal["resources/templates/list"]
431-
params: RequestParams | None = None
432437

433438

434439
class ListResourceTemplatesResult(PaginatedResult):
@@ -570,13 +575,10 @@ class ResourceUpdatedNotification(
570575
params: ResourceUpdatedNotificationParams
571576

572577

573-
class ListPromptsRequest(
574-
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
575-
):
578+
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
576579
"""Sent from the client to request a list of prompts and prompt templates."""
577580

578581
method: Literal["prompts/list"]
579-
params: RequestParams | None = None
580582

581583

582584
class PromptArgument(BaseModel):
@@ -703,11 +705,10 @@ class PromptListChangedNotification(
703705
params: NotificationParams | None = None
704706

705707

706-
class ListToolsRequest(PaginatedRequest[RequestParams | None, Lit 8000 eral["tools/list"]]):
708+
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
707709
"""Sent from the client to request a list of tools the server has."""
708710

709711
method: Literal["tools/list"]
710-
params: RequestParams | None = None
711712

712713

713714
class ToolAnnotations(BaseModel):
@@ -741,7 +742,7 @@ class ToolAnnotations(BaseModel):
741742

742743
idempotentHint: bool | None = None
743744
"""
744-
If true, calling the tool repeatedly with the same arguments
745+
If true, calling the tool repeatedly with the same arguments
745746
will have no additional effect on the its environment.
746747
(This property is meaningful only when `readOnlyHint == false`)
747748
Default: false

tests/client/conftest.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from contextlib import asynccontextmanager
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
import mcp.shared.memory
7+
from mcp.shared.message import SessionMessage
8+
from mcp.types import (
9+
JSONRPCNotification,
10+
JSONRPCRequest,
11+
)
12+
13+
14+
class SpyMemoryObjectSendStream:
15+
def __init__(self, original_stream):
16+
self.original_stream = original_stream
17+
self.sent_messages: list[SessionMessage] = []
18+
19+
async def send(self, message):
20+
self.sent_messages.append(message)
21+
await self.original_stream.send(message)
22+
23+
async def aclose(self):
24+
await self.original_stream.aclose()
25+
26+
async def __aenter__(self):
27+
return self
28+
29+
async def __aexit__(self, *args):
30+
await self.aclose()
31+
32+
33+
class StreamSpyCollection:
34+
def __init__(
35+
self,
36+
client_spy: SpyMemoryObjectSendStream,
37+
server_spy: SpyMemoryObjectSendStream,
38+
):
39+
self.client = client_spy
40+
self.server = server_spy
41+
42+
def clear(self) -> None:
43+
"""Clear all captured messages."""
44+
self.client.sent_messages.clear()
45+
self.server.sent_messages.clear()
46+
47+
def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
48+
"""Get client-sent requests, optionally filtered by method."""
49+
return [
50+
req.message.root
51+
for req in self.client.sent_messages
52+
if isinstance(req.message.root, JSONRPCRequest)
53+
and (method is None or req.message.root.method == method)
54+
]
55+
56+
def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]:
57+
"""Get server-sent requests, optionally filtered by method."""
58+
return [
59+
req.message.root
60+
for req in self.server.sent_messages
61+
if isinstance(req.message.root, JSONRPCRequest)
62+
and (method is None or req.message.root.method == method)
63+
]
64+
65+
def get_client_notifications(
66+
self, method: str | None = None
67+
) -> list[JSONRPCNotification]:
68+
"""Get client-sent notifications, optionally filtered by method."""
69+
return [
70+
notif.message.root
71+
for notif in self.client.sent_messages
72+
if isinstance(notif.message.root, JSONRPCNotification)
73+
and (method is None or notif.message.root.method == method)
74+
]
75+
76+
def get_server_notifications(
77+
self, method: str | None = None
78+
) -> list[JSONRPCNotification]:
79+
"""Get server-sent notifications, optionally filtered by method."""
80+
return [
81+
notif.message.root
82+
for notif in self.server.sent_messages
83+
if isinstance(notif.message.root, JSONRPCNotification)
84+
and (method is None or notif.message.root.method == method)
85+
]
86+
87+
88+
@pytest.fixture
89+
def stream_spy():
90+
"""Fixture that provides spies for both client and server write streams.
91+
92+
Example usage:
93+
async def test_something(stream_spy):
94+
# ... set up server and client ...
95+
96+
spies = stream_spy()
97+
98+
# Run some operation that sends messages
99+
await client.some_operation()
100+
101+
# Check the messages
102+
requests = spies.get_client_requests(method="some/method")
103+
assert len(requests) == 1
104+
105+
# Clear for the next operation
106+
spies.clear()
107+
"""
108+
client_spy = None
109+
server_spy = None
110+
111+
# Store references to our spy objects
112+
def capture_spies(c_spy, s_spy):
113+
nonlocal client_spy, server_spy
114+
client_spy = c_spy
115+
server_spy = s_spy
116+
117+
# Create patched version of stream creation
118+
original_create_streams = mcp.shared.memory.create_client_server_memory_streams
119+
120+
@asynccontextmanager
121+
async def patched_create_streams():
122+
async with original_create_streams() as (client_streams, server_streams):
123+
client_read, client_write = client_streams
124+
server_read, server_write = server_streams
125+
126+
# Create spy wrappers
127+
spy_client_write = SpyMemoryObjectSendStream(client_write)
128+
spy_server_write = SpyMemoryObjectSendStream(server_write)
129+
130+
# Capture references for the test to use
131+
capture_spies(spy_client_write, spy_server_write)
132+
133+
yield (client_read, spy_client_write), (server_read, spy_server_write)
134+
135+
# Apply the patch for the duration of the test
136+
with patch(
137+
"mcp.shared.memory.create_client_server_memory_streams", patched_create_streams
138+
):
139+
# Return a collection with helper methods
140+
def get_spy_collection() -> StreamSpyCollection:
141+
assert client_spy is not None, "client_spy was not initialized"
142+
assert server_spy is not None, "server_spy was not initialized"
143+
return StreamSpyCollection(client_spy, server_spy)
144+
145+
yield get_spy_collection

0 commit comments

Comments
 (0)
0