8000 Merge pull request #42 from modelcontextprotocol/davidsp/types · simonw/python-sdk@99c402d · GitHub
[go: up one dir, main page]

Skip to content

Commit 99c402d

Browse files
authored
Merge pull request modelcontextprotocol#42 from modelcontextprotocol/davidsp/types
Types Rework
2 parents 837309c + ec8c85e commit 99c402d

File tree

14 files changed

+279
-503
lines changed
  • tests
  • 14 files changed

    +279
    -503
    lines changed

    src/mcp/client/session.py

    Lines changed: 79 additions & 158 deletions
    Original file line numberDiff line numberDiff line change
    @@ -3,85 +3,56 @@
    33
    from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
    44
    from pydantic import AnyUrl
    55

    6+
    import mcp.types as types
    67
    from mcp.shared.session import BaseSession
    78
    from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
    8-
    from mcp.types import (
    9-
    LATEST_PROTOCOL_VERSION,
    10-
    CallToolResult,
    11-
    ClientCapabilities,
    12-
    ClientNotification,
    13-
    ClientRequest,
    14-
    ClientResult,
    15-
    CompleteResult,
    16-
    EmptyResult,
    17-
    GetPromptResult,
    18-
    Implementation,
    19-
    InitializedNotification,
    20-
    InitializeResult,
    21-
    JSONRPCMessage,
    22-
    ListPromptsResult,
    23-
    ListResourcesResult,
    24-
    ListToolsResult,
    25-
    LoggingLevel,
    26-
    PromptReference,
    27-
    ReadResourceResult,
    28-
    ResourceReference,
    29-
    RootsCapability,
    30-
    ServerNotification,
    31-
    ServerRequest,
    32-
    )
    339

    3410

    3511
    class ClientSession(
    3612
    BaseSession[
    37-
    ClientRequest,
    38-
    ClientNotification,
    39-
    ClientResult,
    40-
    ServerRequest,
    41-
    ServerNotification,
    13+
    types.ClientRequest,
    14+
    types.ClientNotification,
    15+
    types.ClientResult,
    16+
    types.ServerRequest,
    17+
    types.ServerNotification,
    4218
    ]
    4319
    ):
    4420
    def __init__(
    4521
    self,
    46-
    read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
    47-
    write_stream: MemoryObjectSendStream[JSONRPCMessage],
    22+
    read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
    23+
    write_stream: MemoryObjectSendStream[types.JSONRPCMessage],
    4824
    read_timeout_seconds: timedelta | None = None,
    4925
    ) -> None:
    5026
    super().__init__(
    5127
    read_stream,
    5228
    write_stream,
    53-
    ServerRequest,
    54-
    ServerNotification,
    29+
    types.ServerRequest,
    30+
    types.ServerNotification,
    5531
    read_timeout_seconds=read_timeout_seconds,
    5632
    )
    5733

    58-
    async def initialize(self) -> InitializeResult:
    59-
    from mcp.types import (
    60-
    InitializeRequest,
    61-
    InitializeRequestParams,
    62-
    )
    63-
    34+
    async def initialize(self) -> types.InitializeResult:
    6435
    result = await self.send_request(
    65-
    ClientRequest(
    66-
    InitializeRequest(
    36+
    types.ClientRequest(
    37+
    types.InitializeRequest(
    6738
    method="initialize",
    68-
    params=InitializeRequestParams(
    69-
    protocolVersion=LATEST_PROTOCOL_VERSION,
    70-
    capabilities=ClientCapabilities(
    39+
    params=types.InitializeRequestParams(
    40+
    protocolVersion=types.LATEST_PROTOCOL_VERSION,
    41+
    capabilities=types.ClientCapabilities(
    7142
    sampling=None,
    7243
    experimental=None,
    73-
    roots=RootsCapability(
    44+
    roots=types.RootsCapability(
    7445
    # TODO: Should this be based on whether we
    7546
    # _will_ send notifications, or only whether
    7647
    # they're supported?
    7748
    listChanged=True
    7849
    ),
    7950
    ),
    80-
    clientInfo=Implementation(name="mcp", version="0.1.0"),
    51+
    clientInfo=types.Implementation(name="mcp", version="0.1.0"),
    8152
    ),
    8253
    )
    8354
    ),
    84-
    InitializeResult,
    55+
    types.InitializeResult,
    8556
    )
    8657

    8758
    if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
    @@ -91,40 +62,33 @@ async def initialize(self) -> InitializeResult:
    9162
    )
    9263

    9364
    await self.send_notification(
    94-
    ClientNotification(
    95-
    InitializedNotification(method="notifications/initialized")
    65+
    types.ClientNotification(
    66+
    types.InitializedNotification(method="notifications/initialized")
    9667
    )
    9768
    )
    9869

    9970
    return result
    10071

    101-
    async def send_ping(self) -> EmptyResult:
    72+
    async def send_ping(self) -> types.EmptyResult:
    10273
    """Send a ping request."""
    103-
    from mcp.types import PingRequest
    104-
    10574
    return await self.send_request(
    106-
    ClientRequest(
    107-
    PingRequest(
    75+
    types.ClientRequest(
    76+
    types.PingRequest(
    10877
    method="ping",
    10978
    )
    11079
    ),
    111-
    EmptyResult,
    80+
    types.EmptyResult,
    11281
    )
    11382

    11483
    async def send_progress_notification(
    11584
    self, progress_token: str | int, progress: float, total: float | None = None
    11685
    ) -> None:
    11786
    """Send a progress notification."""
    118-
    from mcp.types import (
    119-
    ProgressNotification,
    120-
    ProgressNotificationParams,
    121-
    )
    122-
    12387
    await self.send_notification(
    124-
    ClientNotification(
    125-
    ProgressNotification(
    88+
    types.ClientNotification(
    89+
    types.ProgressNotification(
    12690
    method="notifications/progress",
    127-
    params=ProgressNotificationParams(
    91+
    params=types.ProgressNotificationParams(
    12892
    progressToken=progress_token,
    12993
    progress=progress,
    13094
    total=total,
    @@ -133,180 +97,137 @@ async def send_progress_notification(
    13397
    )
    13498
    )
    13599

    136-
    async def set_logging_level(self, level: LoggingLevel) -> EmptyResult:
    100+
    async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
    137101
    """Send a logging/setLevel request."""
    138-
    from mcp.types import (
    139-
    SetLevelRequest,
    140-
    SetLevelRequestParams,
    141-
    )
    142-
    143102
    return await self.send_request(
    144-
    ClientRequest(
    145-
    SetLevelRequest(
    103+
    types.ClientRequest(
    104+
    types.SetLevelRequest(
    146105
    method="logging/setLevel",
    147-
    params=SetLevelRequestParams(level=level),
    106+
    params=types.SetLevelRequestParams(level=level),
    148107
    )
    149108
    ),
    150-
    EmptyResult,
    109+
    types.EmptyResult,
    151110
    )
    152111

    153-
    async def list_resources(self) -> ListResourcesResult:
    112+
    async def list_resources(self) -> types.ListResourcesResult:
    154113
    """Send a resources/list request."""
    155-
    from mcp.types import (
    156-
    ListResourcesRequest,
    157-
    )
    158-
    159114
    return await self.send_request(
    160-
    ClientRequest(
    161-
    ListResourcesRequest(
    115+
    types.ClientRequest(
    116+
    types.ListResourcesRequest(
    162117
    method="resources/list",
    163118
    )
    164119
    ),
    165-
    ListResourcesResult,
    120+
    types.ListResourcesResult,
    166121
    )
    167122

    168-
    async def read_resource(self, uri: AnyUrl) -> ReadResourceResult:
    123+
    async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
    169124
    """Send a resources/read request."""
    170-
    from mcp.types import (
    171-
    ReadResourceRequest,
    172-
    ReadResourceRequestParams,
    173-
    )
    174-
    175125
    return await self.send_request(
    176-
    ClientRequest(
    177-
    ReadResourceRequest(
    126+
    types.ClientRequest(
    127+
    types.ReadResourceRequest(
    178128
    method="resources/read",
    179-
    params=ReadResourceRequestParams(uri=uri),
    129+
    params=types.ReadResourceRequestParams(uri=uri),
    180130
    )
    181131
    ),
    182-
    ReadResourceResult,
    132+
    types.ReadResourceResult,
    183133
    )
    184134

    185-
    async def subscribe_resource(self, uri: AnyUrl) -> EmptyResult:
    135+
    async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
    186136
    """Send a resources/subscribe request."""
    187-
    from mcp.types import (
    188-
    SubscribeRequest,
    189-
    SubscribeRequestParams,
    190-
    )
    191-
    192137
    return await self.send_request(
    193-
    ClientRequest(
    194-
    SubscribeRequest(
    138+
    types.ClientRequest(
    139+
    types.SubscribeRequest(
    195140
    method="resources/subscribe",
    196-
    params=SubscribeRequestParams(uri=uri),
    141+
    params=types.SubscribeRequestParams(uri=uri),
    197142
    )
    198143
    ),
    199-
    EmptyResult,
    144+
    types.EmptyResult,
    200145
    )
    201146

    202-
    async def unsubscribe_resource(self, uri: AnyUrl) -> EmptyResult:
    147+
    async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
    203148
    """Send a resources/unsubscribe request."""
    204-
    from mcp.types import (
    205-
    UnsubscribeRequest,
    206-
    UnsubscribeRequestParams,
    207-
    )
    208-
    209149
    return await self.send_request(
    210-
    ClientRequest(
    211-
    UnsubscribeRequest(
    150+
    types.ClientRequest(
    151+
    types.UnsubscribeRequest(
    212152
    method="resources/unsubscribe",
    213-
    params=UnsubscribeRequestParams(uri=uri),
    153+
    params=types.UnsubscribeRequestParams(uri=uri),
    214154
    )
    215155
    ),
    216-
    EmptyResult,
    156+
    types.EmptyResult,
    217157
    )
    218158

    219159
    async def call_tool(
    220160
    self, name: str, arguments: dict | None = None
    221-
    ) -> CallToolResult:
    161+
    ) -> types.CallToolResult:
    222162
    """Send a tools/call request."""
    223-
    from mcp.types import (
    224-
    CallToolRequest,
    225-
    CallToolRequestParams,
    226-
    )
    227-
    228163
    return await self.send_request(
    229-
    ClientRequest(
    230-
    CallToolRequest(
    164+
    types.ClientRequest(
    165+
    types.CallToolRequest(
    231166
    method="tools/call",
    232-
    params=CallToolRequestParams(name=name, arguments=arguments),
    167+
    params=types.CallToolRequestParams(name=name, arguments=arguments),
    233168
    )
    234169
    ),
    235-
    CallToolResult,
    170+
    types.CallToolResult,
    236171
    )
    237172

    238-
    async def list_prompts(self) -> ListPromptsResult:
    173+
    async def list_prompts(self) -> types.ListPromptsResult:
    239174
    """Send a prompts/list request."""
    240-
    from mcp.types import ListPromptsRequest
    241-
    242175
    return await self.send_request(
    243-
    ClientRequest(
    244-
    ListPromptsRequest(
    176+
    types.ClientRequest(
    177+
    types.ListPromptsRequest(
    245178
    method="prompts/list",
    246179
    )
    247180
    ),
    248-
    ListPromptsResult,
    181+
    types.ListPromptsResult,
    249182
    )
    250183

    251184
    async def get_prompt(
    252185
    self, name: str, arguments: dict[str, str] | None = None
    253-
    ) -> GetPromptResult:
    186+
    ) -> types.GetPromptResult:
    254187
    """Send a prompts/get request."""
    255-
    from mcp.types import GetPromptRequest, GetPromptRequestParams
    256-
    257188
    return await self.send_request(
    258-
    ClientRequest(
    259-
    GetPromptRequest(
    189+
    types.ClientRequest(
    190+
    types.GetPromptRequest(
    260191
    method="prompts/get",
    261-
    params=GetPromptRequestParams(name=name, arguments=arguments),
    192+
    params=types.GetPromptRequestParams(name=name, arguments=arguments),
    262193
    )
    263194
    ),
    264-
    GetPromptResult,
    195+
    types.GetPromptResult,
    265196
    )
    266197

    267198
    async def complete(
    268-
    self, ref: ResourceReference | PromptReference, argument: dict
    269-
    ) -> CompleteResult:
    199+
    self, ref: types.ResourceReference | types.PromptReference, argument: dict
    200+
    ) -> types.CompleteResult:
    270201
    """Send a completion/complete request."""
    271-
    from mcp.types import (
    272-
    CompleteRequest,
    273-
    CompleteRequestParams,
    274-
    CompletionArgument,
    275-
    )
    276-
    277202
    return await self.send_request(
    278-
    ClientRequest(
    279-
    CompleteRequest(
    203+
    types.ClientRequest(
    204+
    types.CompleteRequest(
    280205
    method="completion/complete",
    281-
    params=CompleteRequestParams(
    206+
    params=types.CompleteRequestParams(
    282207
    ref=ref,
    283-
    argument=CompletionArgument(**argument),
    208+
    argument=types.CompletionArgument(**argument),
    284209
    ),
    285210
    )
    286211
    ),
    287-
    CompleteResult,
    212+
    types.CompleteResult,
    288213
    )
    289214

    290-
    async def list_tools(self) -> ListToolsResult:
    215+
    async def list_tools(self) -> types.ListToolsResult:
    291216
    """Send a tools/list request."""
    292-
    from mcp.types import ListToolsRequest
    293-
    294217
    return await self.send_request(
    295-
    ClientRequest(
    296-
    ListToolsRequest(
    218+
    types.ClientRequest(
    219+
    types.ListToolsRequest(
    297220
    method="tools/list",
    298221
    )
    299222
    ),
    300-
    ListToolsResult,
    223+
    types.ListToolsResult,
    301224
    )
    302225

    303226
    async def send_roots_list_changed(self) -> None:
    304227
    """Send a roots/list_changed notification."""
    305-
    from mcp.types import RootsListChangedNotification
    306-
    307228
    await self.send_notification(
    308-
    ClientNotification(
    309-
    RootsListChangedNotification(
    229+
    types.ClientNotification(
    230+
    types.RootsListChangedNotification(
    310231
    method="notifications/roots/list_changed",
    311232
    )
    312233
    )

    0 commit comments

    Comments
     (0)
    0