8000 Add support for Elicitation (#625) · error05/python-sdk@69e6572 · GitHub
[go: up one dir, main page]

Skip to content

Commit 69e6572

Browse files
ihrprdsp-ant
andauthored
Add support for Elicitation (modelcontextprotocol#625)
Co-authored-by: David Soria Parra <davidsp@anthropic.com>
1 parent a3bcabd commit 69e6572

File tree

9 files changed

+643
-8
lines changed

9 files changed

+643
-8
lines changed

README.md

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
- [Images](#images)
3232
- [Context](#context)
3333
- [Completions](#completions)
34+
- [Elicitation](#elicitation)
35+
- [Authentication](#authentication)
3436
- [Running Your Server](#running-your-server)
3537
- [Development Mode](#development-mode)
3638
- [Claude Desktop Integration](#claude-desktop-integration)
@@ -74,7 +76,7 @@ The Model Context Protocol allows applications to provide context for LLMs in a
7476

7577
### Adding MCP to your python project
7678

77-
We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects.
79+
We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects.
7880

7981
If you haven't created a uv-managed project yet, create one:
8082

@@ -372,6 +374,50 @@ async def handle_completion(
372374
return Completion(values=filtered)
373375
return None
374376
```
377+
### Elicitation
378+
379+
Request additional information from users during tool execution:
380+
381+
```python
382+
from mcp.server.fastmcp import FastMCP, Context
383+
from mcp.server.elicitation import (
384+
AcceptedElicitation,
385+
DeclinedElicitation,
386+
CancelledElicitation,
387+
)
388+
from pydantic import BaseModel, Field
389+
390+
mcp = FastMCP("Booking System")
391+
392+
393+
@mcp.tool()
394+
async def book_table(date: str, party_size: int, ctx: Context) -> str:
395+
"""Book a table with confirmation"""
396+
397+
# Schema must only contain primitive types (str, int, float, bool)
398+
class ConfirmBooking(BaseModel):
399+
confirm: bool = Field(description="Confirm booking?")
400+
notes: str = Field(default="", description="Special requests")
401+
402+
result = await ctx.elicit(
403+
message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking
404+
)
405+
406+
match result:
407+
case AcceptedElicitation(data=data):
408+
if data.confirm:
409+
return f"Booked! Notes: {data.notes or 'None'}"
410+
return "Booking cancelled"
411+
case DeclinedElicitation():
412+
return "Booking declined"
413+
case CancelledElicitation():
414+
return "Booking cancelled"
415+
```
416+
417+
The `elicit()` method returns an `ElicitationResult` with:
418+
- `action`: "accept", "decline", or "cancel"
419+
- `data`: The validated response (only when accepted)
420+
- `validation_error`: Any validation error message
375421

376422
### Authentication
377423

src/mcp/client/session.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ async def __call__(
2222
) -> types.CreateMessageResult | types.ErrorData: ...
2323

2424

25+
class ElicitationFnT(Protocol):
26+
async def __call__(
27+
self,
28+
context: RequestContext["ClientSession", Any],
29+
params: types.ElicitRequestParams,
30+
) -> types.ElicitResult | types.ErrorData: ...
31+
32+
2533
class ListRootsFnT(Protocol):
2634
async def __call__(
2735
self, context: RequestContext["ClientSession", Any]
@@ -58,6 +66,16 @@ async def _default_sampling_callback(
5866
)
5967

6068

69+
async def _default_elicitation_callback(
70+
context: RequestContext["ClientSession", Any],
71+
params: types.ElicitRequestParams,
72+
) -> types.ElicitResult | types.ErrorData:
73+
return types.ErrorData(
74+
code=types.INVALID_REQUEST,
75+
message="Elicitation not supported",
76+
)
77+
78+
6179
async def _default_list_roots_callback(
6280
context: RequestContext["ClientSession", Any],
6381
) -> types.ListRootsResult | types.ErrorData:
@@ -91,6 +109,7 @@ def __init__(
91109
write_stream: MemoryObjectSendStream[SessionMessage],
92110
read_timeout_seconds: timedelta | None = None,
93111
sampling_callback: SamplingFnT | None = None,
112+
elicitation_callback: ElicitationFnT | None = None,
94113
list_roots_callback: ListRootsFnT | None = None,
95114
logging_callback: LoggingFnT | None = None,
96115
message_handler: MessageHandlerFnT | None = None,
@@ -105,12 +124,16 @@ def __init__(
105124
)
106125
self._client_info = client_info or DEFAULT_CLIENT_INFO
107126
self._sampling_callback = sampling_callback or _default_sampling_callback
F438 127+
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
108128
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
109129
self._logging_callback = logging_callback or _default_logging_callback
110130
self._message_handler = message_handler or _default_message_handler
111131

112132
async def initialize(self) -> types.InitializeResult:
113133
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
134+
elicitation = (
135+
types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None
136+
)
114137
roots = (
115138
# TODO: Should this be based on whether we
116139
# _will_ send notifications, or only whether
@@ -128,6 +151,7 @@ async def initialize(self) -> types.InitializeResult:
128151
protocolVersion=types.LATEST_PROTOCOL_VERSION,
129152
capabilities=types.ClientCapabilities(
130153
sampling=sampling,
154+
elicitation=elicitation,
131155
experimental=None,
132156
roots=roots,
133157
),
@@ -362,6 +386,12 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
362386
client_response = ClientResponse.validate_python(response)
363387
await responder.respond(client_response)
364388

389+
case types.ElicitRequest(params=params):
390+
with responder:
391+
response = await self._elicitation_callback(ctx, params)
392+
client_response = ClientResponse.validate_python(response)
393+
await responder.respond(client_response)
394+
365395
case types.ListRootsRequest():
366396
with responder:
367397
response = await self._list_roots_callback(ctx)

src/mcp/server/elicitation.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Elicitation utilities for MCP servers."""
2+
3+
from __future__ import annotations
4+
5+
import types
6+
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
7+
8+
from pydantic import BaseModel
9+
from pydantic.fields import FieldInfo
10+
11+
from mcp.server.session import ServerSession
12+
from mcp.types import RequestId
13+
14+
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
15+
16+
17+
class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
18+
"""Result when user accepts the elicitation."""
19+
20+
action: Literal["accept"] = "accept"
21+
data: ElicitSchemaModelT
22+
23+
24+
class DeclinedElicitation(BaseModel):
25+
"""Result when user declines the elicitation."""
26+
27+
action: Literal["decline"] = "decline"
28+
29+
30+
class CancelledElicitation(BaseModel):
31+
"""Result when user cancels the elicitation."""
32+
33+
action: Literal["cancel"] = "cancel"
34+
35+
36+
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation
37+
38+
39+
# Primitive types allowed in elicitation schemas
40+
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
41+
42+
43+
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
44+
"""Validate that a Pydantic model only contains primitive field types."""
45+
for field_name, field_info in schema.model_fields.items():
46+
if not _is_primitive_field(field_info):
47+
raise TypeError(
48+
f"Elicitation schema field '{field_name}' must be a primitive type "
49+
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
50+
f"Complex types like lists, dicts, or nested models are not allowed."
51+
)
52+
53+
54+
def _is_primitive_field(field_info: FieldInfo) -> bool:
55+
"""Check if a field is a primitive type allowed in elicitation schemas."""
56+
annotation = field_info.annotation
57+
58+
# Handle None type
59+
if annotation is types.NoneType:
60+
return True
61+
62+
# Handle basic primitive types
63+
if annotation in _ELICITATION_PRIMITIVE_TYPES:
64+
return True
65+
66+
# Handle Union types
67+
origin = get_origin(annotation)
68+
if origin is Union or origin is types.UnionType:
69+
args = get_args(annotation)
70+
# All args must be primitive types or None
71+
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
72+
73+
return False
74+
75+
76+
async def elicit_with_validation(
77+
session: ServerSession,
78+
message: str,
79+
schema: type[ElicitSchemaModelT],
80+
related_request_id: RequestId | None = None,
81+
) -> ElicitationResult[ElicitSchemaModelT]:
82+
"""Elicit information from the client/user with schema validation.
83+
84+
This method can be used to interactively ask for additional information from the
85+
client within a tool's execution. The client might display the message to the
86+
user and collect a response according to the provided schema. Or in case a
87+
client is an agent, it might decide how to handle the elicitation -- either by asking
88+
the user or automatically generating a response.
89+
"""
90+
# Validate that schema only contains primitive types and fail loudly if not
91+
_validate_elicitation_schema(schema)
92+
93+
json_schema = schema.model_json_schema()
94+
95+
result = await session.elicit(
96+
message=message,
97+
requestedSchema=json_schema,
98+
related_request_id=related_request_id,
99+
)
100+
101+
if result.action == "accept" and result.content:
102+
# Validate and parse the content using the schema
103+
validated_data = schema.model_validate(result.content)
104+
return AcceptedElicitation(data=validated_data)
105+
elif result.action == "decline":
106+
return DeclinedElicitation()
107+
elif result.action == "cancel":
108+
return CancelledElicitation()
109+
else:
110+
# This should never happen, but handle it just in case
111+
raise ValueError(f"Unexpected elicitation action: {result.action}")

src/mcp/server/fastmcp/server.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from mcp.server.auth.settings import (
3535
AuthSettings,
3636
)
37+
from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation
3738
from mcp.server.fastmcp.exceptions import ResourceError
3839
from mcp.server.fastmcp.prompts import Prompt, PromptManager
3940
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
@@ -972,6 +973,37 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
972973
assert self._fastmcp is not None, "Context is not available outside of a request"
973974
return await self._fastmcp.read_resource(uri)
974975

976+
async def elicit(
977+
self,
978+
message: str,
979+
schema: type[ElicitSchemaModelT],
980+
) -> ElicitationResult[ElicitSchemaModelT]:
981+
"""Elicit information from the client/user.
982+
983+
This method can be used to interactively ask for additional information from the
984+
client within a tool's execution. The client might display the message to the
985+
user and collect a response according to the provided schema. Or in case a
986+
client is an agent, it might decide how to handle the elicitation -- either by asking
987+
the user or automatically generating a response.
988+
989+
Args:
990+
schema: A Pydantic model class defining the expected response structure, according to the specification,
991+
only primive types are allowed.
992+
message: Optional message to present to the user. If not provided, will use
993+
a default message based on the schema
994+
995+
Returns:
996+
An ElicitationResult containing the action taken and the data if accepted
997+
998+
Note:
999+
Check the result.action to determine if the user accepted, declined, or cancelled.
1000+
The result.data will only be populated if action is "accept" and validation succeeded.
1001+
"""
1002+
1003+
return await elicit_with_validation(
1004+
session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id
1005+
)
1006+
9751007
async def log(
9761008
self,
9771009
level: Literal["debug", "info", "warning", "error"],

src/mcp/server/session.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
121121
if client_caps.sampling is None:
122122
return False
123123

124+
if capability.elicitation is not None:
125+
if client_caps.elicitation is None:
126+
return False
127+
124128
if capability.experimental is not None:
125129
if client_caps.experimental is None:
126130
return False
@@ -251,6 +255,35 @@ async def list_roots(self) -> types.ListRootsResult:
251255
types.ListRootsResult,
252256
)
253257

258+
async def elicit(
259+
self,
260+
message: str,
261+
requestedSchema: types.ElicitRequestedSchema,
262+
related_request_id: types.RequestId | None = None,
263+
) -> types.ElicitResult:
264+
"""Send an elicitation/create request.
265+
266+
Args:
267+
message: The message to present to the user
268+
requestedSchema: Schema defining the expected response structure
269+
270+
Returns:
271+
The client's response
272+
"""
273+
return await self.send_request(
274+
types.ServerRequest(
275+
types.ElicitRequest(
276+
method="elicitation/create",
277+
params=types.ElicitRequestParams(
278+
message=message,
279+
requestedSchema=requestedSchema,
280+
),
281+
)
282+
),
283+
types.ElicitResult,
284+
metadata=ServerMessageMetadata(related_request_id=related_request_id),
285+
)
286+
254287
async def send_ping(self) -> types.EmptyResult:
255288
"""Send a ping request."""
256289
return await self.send_request(

src/mcp/shared/memory.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1212

1313
import mcp.types as types
14-
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
14+
from mcp.client.session import (
15+
ClientSession,
16+
ElicitationFnT,
17+
ListRootsFnT,
18+
LoggingFnT,
19+
MessageHandlerFnT,
20+
SamplingFnT,
21+
)
1522
from mcp.server import Server
1623
from mcp.shared.message import SessionMessage
1724

@@ -53,6 +60,7 @@ async def create_connected_server_and_client_session(
5360
message_handler: MessageHandlerFnT | None = None,
5461
client_info: types.Implementation | None = None,
5562
raise_exceptions: bool = False,
63+
elicitation_callback: ElicitationFnT | None = None,
5664
) -> AsyncGenerator[ClientSession, None]:
5765
"""Creates a ClientSession that is connected to a running MCP server."""
5866
async with create_client_server_memory_streams() as (
@@ -83,6 +91,7 @@ async def create_connected_server_and_client_session(
8391
logging_callback=logging_callback,
8492
message_handler=message_handler,
8593
client_info=client_info,
94+
elicitation_callback=elicitation_callback,
8695
) as client_session:
8796
await client_session.initialize()
8897
yield client_session

0 commit comments

Comments
 (0)
0