Skip to content

Commit b469c57

Browse files
author
Tapan Chugh
committed
SEP: Elicitation Enum Schema Improvements and Standards Compliance
1 parent ef4e167 commit b469c57

File tree

3 files changed

+134
-8
lines changed

3 files changed

+134
-8
lines changed

src/mcp/server/elicitation.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import types
6+
from collections.abc import Sequence
67
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
78

89
from pydantic import BaseModel
@@ -46,11 +47,22 @@ def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
4647
if not _is_primitive_field(field_info):
4748
raise TypeError(
4849
f"Elicitation schema field '{field_name}' must be a primitive type "
49-
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
50-
f"Complex types like lists, dicts, or nested models are not allowed."
50+
f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
51+
f"or Optional of these types. Nested models and complex types are not allowed."
5152
)
5253

5354

55+
def _is_string_sequence(annotation: type) -> bool:
56+
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
57+
origin = get_origin(annotation)
58+
# Check if it's a sequence-like type with str elements
59+
if origin and issubclass(origin, Sequence):
60+
args = get_args(annotation)
61+
# Should have single str type arg
62+
return len(args) == 1 and args[0] is str
63+
return False
64+
65+
5466
def _is_primitive_field(field_info: FieldInfo) -> bool:
5567
"""Check if a field is a primitive type allowed in elicitation schemas."""
5668
annotation = field_info.annotation
@@ -63,12 +75,21 @@ def _is_primitive_field(field_info: FieldInfo) -> bool:
6375
if annotation in _ELICITATION_PRIMITIVE_TYPES:
6476
return True
6577

78+
# Handle string sequences for multi-select enums
79+
if annotation is not None and _is_string_sequence(annotation):
80+
return True
81+
6682
# Handle Union types
6783
origin = get_origin(annotation)
6884
if origin is Union or origin is types.UnionType:
6985
args = get_args(annotation)
70-
# All args must be primitive types or None
71-
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)
86+
# All args must be primitive types, None, or string sequences
87+
return all(
88+
arg is types.NoneType
89+
or arg in _ELICITATION_PRIMITIVE_TYPES
90+
or (arg is not None and _is_string_sequence(arg))
91+
for arg in args
92+
)
7293

7394
return False
7495

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1274,7 +1274,7 @@ class ElicitResult(Result):
12741274
- "cancel": User dismissed without making an explicit choice
12751275
"""
12761276

1277-
content: dict[str, str | int | float | bool | None] | None = None
1277+
content: dict[str, str | int | float | bool | list[str] | None] | None = None
12781278
"""
12791279
The submitted form data, only present when action is "accept".
12801280
Contains values matching the requested schema.

tests/server/fastmcp/test_elicitation.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ async def tool(ctx: Context) -> str:
112112

113113
# Test cases for invalid schemas
114114
class InvalidListSchema(BaseModel):
115-
names: list[str] = Field(description="List of names")
115+
numbers: list[int] = Field(description="List of numbers")
116116

117117
class NestedModel(BaseModel):
118118
value: str
@@ -133,7 +133,7 @@ async def elicitation_callback(context, params):
133133
await client_session.initialize()
134134

135135
# Test both invalid schemas
136-
for tool_name, field_name in [("invalid_list", "names"), ("nested_model", "nested")]:
136+
for tool_name, field_name in [("invalid_list", "numbers"), ("nested_model", "nested")]:
137137
result = await client_session.call_tool(tool_name, {})
138138
assert len(result.content) == 1
139139
assert isinstance(result.content[0], TextContent)
@@ -191,7 +191,7 @@ async def callback(context, params):
191191
# Test invalid optional field
192192
class InvalidOptionalSchema(BaseModel):
193193
name: str = Field(description="Name")
194-
optional_list: list[str] | None = Field(default=None, description="Invalid optional list")
194+
optional_list: list[int] | None = Field(default=None, description="Invalid optional list")
195195

196196
@mcp.tool(description="Tool with invalid optional field")
197197
async def invalid_optional_tool(ctx: Context) -> str:
@@ -208,3 +208,108 @@ async def invalid_optional_tool(ctx: Context) -> str:
208208
{},
209209
text_contains=["Validation failed:", "optional_list"],
210210
)
211+
212+
# Test valid list[str] for multi-select enum
213+
class ValidMultiSelectSchema(BaseModel):
214+
name: str = Field(description="Name")
215+
tags: list[str] = Field(description="Tags")
216+
217+
@mcp.tool(description="Tool with valid list[str] field")
218+
async def valid_multiselect_tool(ctx: Context) -> str:
219+
result = await ctx.elicit(message="Please provide tags", schema=ValidMultiSelectSchema)
220+
if result.action == "accept" and result.data:
221+
return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}"
222+
return f"User {result.action}"
223+
224+
async def multiselect_callback(context, params):
225+
if "Please provide tags" in params.message:
226+
return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]})
227+
return ElicitResult(action="decline")
228+
229+
await call_tool_and_assert(mcp, multiselect_callback, "valid_multiselect_tool", {}, "Name: Test, Tags: tag1, tag2")
230+
231+
232+
@pytest.mark.anyio
233+
async def test_elicitation_with_enum_titles():
234+
"""Test elicitation with enum schemas using oneOf/anyOf for titles."""
235+
mcp = FastMCP(name="ColorPreferencesApp")
236+
237+
# Test single-select with titles using oneOf
238+
class FavoriteColorSchema(BaseModel):
239+
user_name: str = Field(description="Your name")
240+
favorite_color: str = Field(
241+
description="Select your favorite color",
242+
json_schema_extra={
243+
"oneOf": [
244+
{"const": "red", "title": "Red"},
245+
{"const": "green", "title": "Green"},
246+
{"const": "blue", "title": "Blue"},
247+
{"const": "yellow", "title": "Yellow"},
248+
]
249+
},
250+
)
251+
252+
@mcp.tool(description="Single color selection")
253+
async def select_favorite_color(ctx: Context) -> str:
254+
result = await ctx.elicit(message="Select your favorite color", schema=FavoriteColorSchema)
255+
if result.action == "accept" and result.data:
256+
return f"User: {result.data.user_name}, Favorite: {result.data.favorite_color}"
257+
return f"User {result.action}"
258+
259+
# Test multi-select with titles using anyOf
260+
class FavoriteColorsSchema(BaseModel):
261+
user_name: str = Field(description="Your name")
262+
favorite_colors: list[str] = Field(
263+
description="Select your favorite colors",
264+
json_schema_extra={
265+
"items": {
266+
"anyOf": [
267+
{"const": "red", "title": "Red"},
268+
{"const": "green", "title": "Green"},
269+
{"const": "blue", "title": "Blue"},
270+
{"const": "yellow", "title": "Yellow"},
271+
]
272+
}
273+
},
274+
)
275+
276+
@mcp.tool(description="Multiple color selection")
277+
async def select_favorite_colors(ctx: Context) -> str:
278+
result = await ctx.elicit(message="Select your favorite colors", schema=FavoriteColorsSchema)
279+
if result.action == "accept" and result.data:
280+
return f"User: {result.data.user_name}, Colors: {', '.join(result.data.favorite_colors)}"
281+
return f"User {result.action}"
282+
283+
# Test deprecated enumNames format
284+
class DeprecatedColorSchema(BaseModel):
285+
user_name: str = Field(description="Your name")
286+
color: str = Field(
287+
description="Select a color",
288+
json_schema_extra={"enum": ["red", "green", "blue"], "enumNames": ["Red", "Green", "Blue"]},
289+
)
290+
291+
@mcp.tool(description="Deprecated enum format")
292+
async def select_color_deprecated(ctx: Context) -> str:
293+
result = await ctx.elicit(message="Select a color (deprecated format)", schema=DeprecatedColorSchema)
294+
if result.action == "accept" and result.data:
295+
return f"User: {result.data.user_name}, Color: {result.data.color}"
296+
return f"User {result.action}"
297+
298+
async def enum_callback(context, params):
299+
if "colors" in params.message and "deprecated" not in params.message:
300+
return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]})
301+
elif "color" in params.message:
302+
if "deprecated" in params.message:
303+
return ElicitResult(action="accept", content={"user_name": "Charlie", "color": "green"})
304+
else:
305+
return ElicitResult(action="accept", content={"user_name": "Alice", "favorite_color": "blue"})
306+
return ElicitResult(action="decline")
307+
308+
# Test single-select with titles
309+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_color", {}, "User: Alice, Favorite: blue")
310+
311+
# Test multi-select with titles
312+
await call_tool_and_assert(mcp, enum_callback, "select_favorite_colors", {}, "User: Bob, Colors: red, green")
313+
314+
# Test deprecated enumNames format
315+
await call_tool_and_assert(mcp, enum_callback, "select_color_deprecated", {}, "User: Charlie, Color: green")

0 commit comments

Comments
 (0)