Skip to content

Commit 9eea409

Browse files
author
Tapan Chugh
committed
feat: Use AnyUrl type for URI fields and update managers
- Change Tool and Prompt classes to use AnyUrl for uri field - Update all manager get/call methods to accept AnyUrl | str - Add method overloads for better type hints - Update tests to use AnyUrl objects - Update filter_by_uri_paths to use Sequence for covariance
1 parent e4614ba commit 9eea409

File tree

14 files changed

+240
-1031
lines changed

14 files changed

+240
-1031
lines changed

src/mcp/.types.py.~undo-tree~

Lines changed: 0 additions & 881 deletions
This file was deleted.

src/mcp/server/fastmcp/prompts/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Literal
66

77
import pydantic_core
8-
from pydantic import BaseModel, Field, TypeAdapter, validate_call
8+
from pydantic import AnyUrl, BaseModel, Field, TypeAdapter, validate_call
99

1010
from mcp.types import PROMPT_SCHEME, ContentBlock, TextContent
1111

@@ -58,7 +58,7 @@ class Prompt(BaseModel):
5858
"""A prompt template that can be rendered with parameters."""
5959

6060
name: str = Field(description="Name of the prompt")
61-
uri: str = Field(description="URI of the prompt")
61+
uri: AnyUrl = Field(description="URI of the prompt")
6262
title: str | None = Field(None, description="Human-readable title of the prompt")
6363
description: str | None = Field(None, description="Description of what the prompt does")
6464
arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
@@ -67,7 +67,7 @@ class Prompt(BaseModel):
6767
def __init__(self, **data: Any) -> None:
6868
"""Initialize Prompt, generating URI from name if not provided."""
6969
if "uri" not in data and "name" in data:
70-
data["uri"] = f"{PROMPT_SCHEME}/{data['name']}"
70+
data["uri"] = AnyUrl(f"{PROMPT_SCHEME}/{data['name']}")
7171
super().__init__(**data)
7272

7373
@classmethod

src/mcp/server/fastmcp/prompts/manager.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Prompt management functionality."""
22

3-
from typing import Any
3+
from typing import Any, overload
4+
5+
from pydantic import AnyUrl
46

57
from mcp.server.fastmcp.prompts.base import Message, Prompt
68
from mcp.server.fastmcp.uri_utils import filter_by_uri_paths, normalize_to_prompt_uri
@@ -20,15 +22,28 @@ def _normalize_to_uri(self, name_or_uri: str) -> str:
2022
"""Convert name to URI if needed."""
2123
return normalize_to_prompt_uri(name_or_uri)
2224

23-
def get_prompt(self, name: str) -> Prompt | None:
25+
@overload
26+
def get_prompt(self, name_or_uri: str) -> Prompt | None:
27+
"""Get prompt by name."""
28+
...
29+
30+
@overload
31+
def get_prompt(self, name_or_uri: AnyUrl) -> Prompt | None:
32+
"""Get prompt by URI."""
33+
...
34+
35+
def get_prompt(self, name_or_uri: AnyUrl | str) -> Prompt | None:
2436
"""Get prompt by name or URI."""
25-
uri = self._normalize_to_uri(name)
37+
if isinstance(name_or_uri, AnyUrl):
38+
return self._prompts.get(str(name_or_uri))
39+
uri = self._normalize_to_uri(name_or_uri)
2640
return self._prompts.get(uri)
2741

28-
def list_prompts(self, uri_paths: list[str] | None = None) -> list[Prompt]:
42+
def list_prompts(self, uri_paths: list[AnyUrl] | None = None) -> list[Prompt]:
2943
"""List all registered prompts, optionally filtered by URI paths."""
3044
prompts = list(self._prompts.values())
31-
prompts = filter_by_uri_paths(prompts, uri_paths, lambda p: p.uri)
45+
if uri_paths:
46+
prompts = filter_by_uri_paths(prompts, uri_paths)
3247
logger.debug("Listing prompts", extra={"count": len(prompts), "uri_paths": uri_paths})
3348
return prompts
3449

@@ -40,19 +55,29 @@ def add_prompt(
4055
logger.debug(f"Adding prompt: {prompt.name} with URI: {prompt.uri}")
4156

4257
# Check for duplicates
43-
existing = self._prompts.get(prompt.uri)
58+
existing = self._prompts.get(str(prompt.uri))
4459
if existing:
4560
if self.warn_on_duplicate_prompts:
4661
logger.warning(f"Prompt already exists: {prompt.uri}")
4762
return existing
4863

49-
self._prompts[prompt.uri] = prompt
64+
self._prompts[str(prompt.uri)] = prompt
5065
return prompt
5166

52-
async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
67+
@overload
68+
async def render_prompt(self, name_or_uri: str, arguments: dict[str, Any] | None = None) -> list[Message]:
5369
"""Render a prompt by name with arguments."""
54-
prompt = self.get_prompt(name)
70+
...
71+
72+
@overload
73+
async def render_prompt(self, name_or_uri: AnyUrl, arguments: dict[str, Any] | None = None) -> list[Message]:
74+
"""Render a prompt by URI with arguments."""
75+
...
76+
77+
async def render_prompt(self, name_or_uri: AnyUrl | str, arguments: dict[str, Any] | None = None) -> list[Message]:
78+
"""Render a prompt by name or URI with arguments."""
79+
prompt = self.get_prompt(name_or_uri)
5580
if not prompt:
56-
raise ValueError(f"Unknown prompt: {name}")
81+
raise ValueError(f"Unknown prompt: {name_or_uri}")
5782

5883
return await prompt.render(arguments)

src/mcp/server/fastmcp/prompts/prompt_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Prompt management functionality."""
22

3+
from pydantic import AnyUrl
4+
35
from mcp.server.fastmcp.prompts.base import Prompt
46
from mcp.server.fastmcp.uri_utils import filter_by_uri_paths, normalize_to_prompt_uri
57
from mcp.server.fastmcp.utilities.logging import get_logger
@@ -34,9 +36,10 @@ def get_prompt(self, name: str) -> Prompt | None:
3436
uri = self._normalize_to_uri(name)
3537
return self._prompts.get(uri)
3638

37-
def list_prompts(self, uri_paths: list[str] | None = None) -> list[Prompt]:
39+
def list_prompts(self, uri_paths: list[AnyUrl] | None = None) -> list[Prompt]:
3840
"""List all registered prompts, optionally filtered by URI paths."""
3941
prompts = list(self._prompts.values())
40-
prompts = filter_by_uri_paths(prompts, uri_paths, lambda p: p.uri)
42+
if uri_paths:
43+
prompts = filter_by_uri_paths(prompts, uri_paths)
4144
logger.debug("Listing prompts", extra={"count": len(prompts), "uri_paths": uri_paths})
4245
return prompts

src/mcp/server/fastmcp/resources/resource_manager.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,24 +87,26 @@ async def get_resource(self, uri: AnyUrl | str) -> Resource | None:
8787

8888
raise ValueError(f"Unknown resource: {uri}")
8989

90-
def list_resources(self, uri_paths: list[str] | None = None) -> list[Resource]:
90+
def list_resources(self, uri_paths: list[AnyUrl] | None = None) -> list[Resource]:
9191
"""List all registered resources, optionally filtered by URI paths."""
9292
resources = list(self._resources.values())
93-
resources = filter_by_uri_paths(resources, uri_paths, lambda r: r.uri)
93+
if uri_paths:
94+
resources = filter_by_uri_paths(resources, uri_paths)
9495
logger.debug("Listing resources", extra={"count": len(resources), "uri_paths": uri_paths})
9596
return resources
9697

97-
def list_templates(self, uri_paths: list[str] | None = None) -> list[ResourceTemplate]:
98+
def list_templates(self, uri_paths: list[AnyUrl] | None = None) -> list[ResourceTemplate]:
9899
"""List all registered templates, optionally filtered by URI paths."""
99100
templates = list(self._templates.values())
100101
if uri_paths:
101102
filtered: list[ResourceTemplate] = []
102103
for template in templates:
103104
for prefix in uri_paths:
104105
# Ensure prefix ends with / for proper path matching
105-
if not prefix.endswith("/"):
106-
prefix = prefix + "/"
107-
if template.matches_prefix(prefix):
106+
prefix_str = str(prefix)
107+
if not prefix_str.endswith("/"):
108+
prefix_str = prefix_str + "/"
109+
if template.matches_prefix(prefix_str):
108110
filtered.append(template)
109111
break
110112
templates = filtered

src/mcp/server/fastmcp/tools/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from functools import cached_property
77
from typing import TYPE_CHECKING, Any, get_origin
88

9-
from pydantic import BaseModel, Field
9+
from pydantic import AnyUrl, BaseModel, Field
1010

1111
from mcp.server.fastmcp.exceptions import ToolError
1212
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
@@ -22,7 +22,7 @@ class Tool(BaseModel):
2222
"""Internal tool registration info."""
2323

2424
name: str = Field(description="Name of the tool")
25-
uri: str = Field(description="URI of the tool")
25+
uri: AnyUrl = Field(description="URI of the tool")
2626
title: str | None = Field(None, description="Human-readable title of the tool")
2727
description: str = Field(description="Description of what the tool does")
2828
fn: Callable[..., Any] = Field(exclude=True)
@@ -37,7 +37,7 @@ class Tool(BaseModel):
3737
def __init__(self, **data: Any) -> None:
3838
"""Initialize Tool, generating URI from name if not provided."""
3939
if "uri" not in data and "name" in data:
40-
data["uri"] = f"{TOOL_SCHEME}/{data['name']}"
40+
data["uri"] = AnyUrl(f"{TOOL_SCHEME}/{data['name']}")
4141
super().__init__(**data)
4242

4343
@cached_property

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations as _annotations
22

33
from collections.abc import Callable
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, overload
5+
6+
from pydantic import AnyUrl
57

68
from mcp.server.fastmcp.exceptions import ToolError
79
from mcp.server.fastmcp.tools.base import Tool
@@ -39,15 +41,28 @@ def _normalize_to_uri(self, name_or_uri: str) -> str:
3941
"""Convert name to URI if needed."""
4042
return normalize_to_tool_uri(name_or_uri)
4143

42-
def get_tool(self, name: str) -> Tool | None:
44+
@overload
45+
def get_tool(self, name_or_uri: str) -> Tool | None:
46+
"""Get tool by name."""
47+
...
48+
49+
@overload
50+
def get_tool(self, name_or_uri: AnyUrl) -> Tool | None:
51+
"""Get tool by URI."""
52+
...
53+
54+
def get_tool(self, name_or_uri: AnyUrl | str) -> Tool | None:
4355
"""Get tool by name or URI."""
44-
uri = self._normalize_to_uri(name)
56+
if isinstance(name_or_uri, AnyUrl):
57+
return self._tools.get(str(name_or_uri))
58+
uri = self._normalize_to_uri(name_or_uri)
4559
return self._tools.get(uri)
4660

47-
def list_tools(self, uri_paths: list[str] | None = None) -> list[Tool]:
61+
def list_tools(self, uri_paths: list[AnyUrl] | None = None) -> list[Tool]:
4862
"""List all registered tools, optionally filtered by URI paths."""
4963
tools = list(self._tools.values())
50-
tools = filter_by_uri_paths(tools, uri_paths, lambda t: t.uri)
64+
if uri_paths:
65+
tools = filter_by_uri_paths(tools, uri_paths)
5166
logger.debug("Listing tools", extra={"count": len(tools), "uri_paths": uri_paths})
5267
return tools
5368

@@ -77,16 +92,38 @@ def add_tool(
7792
self._tools[str(tool.uri)] = tool
7893
return tool
7994

95+
@overload
8096
async def call_tool(
8197
self,
82-
name: str,
98+
name_or_uri: str,
8399
arguments: dict[str, Any],
84100
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
85101
convert_result: bool = False,
86102
) -> Any:
87103
"""Call a tool by name with arguments."""
88-
tool = self.get_tool(name)
104+
...
105+
106+
@overload
107+
async def call_tool(
108+
self,
109+
name_or_uri: AnyUrl,
110+
arguments: dict[str, Any],
111+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
112+
convert_result: bool = False,
113+
) -> Any:
114+
"""Call a tool by URI with arguments."""
115+
...
116+
117+
async def call_tool(
118+
self,
119+
name_or_uri: AnyUrl | str,
120+
arguments: dict[str, Any],
121+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
122+
convert_result: bool = False,
123+
) -> Any:
124+
"""Call a tool by name or URI with arguments."""
125+
tool = self.get_tool(name_or_uri)
89126
if not tool:
90-
raise ToolError(f"Unknown tool: {name}")
127+
raise ToolError(f"Unknown tool: {name_or_uri}")
91128

92129
return await tool.run(arguments, context=context, convert_result=convert_result)

src/mcp/server/fastmcp/uri_utils.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
"""Common URI utilities for FastMCP."""
22

3-
from collections.abc import Callable
4-
from typing import TypeVar
3+
from collections.abc import Sequence
4+
from typing import Protocol, TypeVar, runtime_checkable
55

66
from pydantic import AnyUrl
77

88
from mcp.types import PROMPT_SCHEME, TOOL_SCHEME
99

10-
T = TypeVar("T")
10+
T = TypeVar("T", bound="HasUri")
1111

1212

1313
def normalize_to_uri(name_or_uri: str, scheme: str) -> str:
@@ -35,34 +35,37 @@ def normalize_to_prompt_uri(name_or_uri: str) -> str:
3535
return normalize_to_uri(name_or_uri, PROMPT_SCHEME)
3636

3737

38-
def filter_by_uri_paths(
39-
items: list[T], uri_paths: list[str] | None, uri_getter: Callable[[T], AnyUrl | str]
40-
) -> list[T]:
38+
@runtime_checkable
39+
class HasUri(Protocol):
40+
"""Protocol for objects that have a URI attribute."""
41+
42+
uri: AnyUrl
43+
44+
45+
def filter_by_uri_paths(items: Sequence[T], uri_paths: Sequence[AnyUrl]) -> list[T]:
4146
"""Filter items by multiple URI path prefixes.
4247
4348
Args:
44-
items: List of items to filter
45-
uri_paths: Optional list of URI path prefixes to filter by. If None or empty, returns all items.
46-
uri_getter: Function to extract URI from an item
49+
items: List of items that have a 'uri' attribute
50+
uri_paths: List of URI path prefixes to filter by.
4751
4852
Returns:
4953
Filtered list of items matching any of the provided prefixes
5054
"""
51-
if not uri_paths:
52-
return items
5355

5456
# Filter items where the URI matches any of the prefixes
5557
filtered: list[T] = []
5658
for item in items:
57-
uri = str(uri_getter(item))
59+
uri = str(item.uri)
5860
for prefix in uri_paths:
59-
if uri.startswith(prefix):
61+
prefix_str = str(prefix)
62+
if uri.startswith(prefix_str):
6063
# If prefix ends with a separator, we already have a proper boundary
61-
if prefix.endswith(("/", "?", "#")):
64+
if prefix_str.endswith(("/", "?", "#")):
6265
filtered.append(item)
6366
break
6467
# Otherwise check if it's an exact match or if the next character is a separator
65-
elif len(uri) == len(prefix) or uri[len(prefix)] in ("/", "?", "#"):
68+
elif len(uri) == len(prefix_str) or uri[len(prefix_str)] in ("/", "?", "#"):
6669
filtered.append(item)
6770
break
6871

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class Meta(BaseModel):
6363
class ListFilters(BaseModel):
6464
"""Filters for list operations."""
6565

66-
uri_paths: list[str] | None = None
66+
uri_paths: list[AnyUrl] | None = None
6767
"""Optional list of absolute URI path prefixes to filter results."""
6868

6969

0 commit comments

Comments
 (0)