Should we enforce validation of the method field in xxxRequest? #75

@zhongyi51

In this repository, the ClientRequest enum is marked as #[serde(untagged)], so serde will deserialize this enum by trying each variant in order:

#[derive(:: serde :: Deserialize, :: serde :: Serialize, Clone, Debug)]
#[serde(untagged)]
pub enum ClientRequest {
    InitializeRequest(InitializeRequest),
    PingRequest(PingRequest),
    ListResourcesRequest(ListResourcesRequest),
    ListResourceTemplatesRequest(ListResourceTemplatesRequest),
    ReadResourceRequest(ReadResourceRequest),
    SubscribeRequest(SubscribeRequest),
    UnsubscribeRequest(UnsubscribeRequest),
    ListPromptsRequest(ListPromptsRequest),
    GetPromptRequest(GetPromptRequest),
    ListToolsRequest(ListToolsRequest),
    CallToolRequest(CallToolRequest),
    SetLevelRequest(SetLevelRequest),
    CompleteRequest(CompleteRequest),
}

According to the serde documentation, serde tries each variant of an untagged enum in order and returns the first one that deserializes successfully.

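In the JSON schema, the method field of each xxxRequest is declared as a constant, but this constraint is not enforced during deserialization. Since method is a plain String, any method name, including an illegal one, is accepted from JSON. Combined with the in-order matching above, this makes the result ambiguous: for example, because PingRequest is tried second and its params are optional, a request such as {"method": "tools/list"} (or even {"method": "no/such/method"}) would likely be deserialized as ClientRequest::PingRequest.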

So I think we should add a custom deserializer for the method field:

use serde::{Deserialize, Deserializer};

#[derive(:: serde :: Deserialize, :: serde :: Serialize, Clone, Debug)]
pub struct ListResourcesRequest {
    #[serde(deserialize_with = "deserialize_list_resources_method")]
    pub method: ::std::string::String,
    #[serde(default, skip_serializing_if = "::std::option::Option::is_none")]
    pub params: ::std::option::Option<ListResourcesRequestParams>,
}

// other code...

// Accept only the constant method name that the schema defines for this request.
fn deserialize_list_resources_method<'de, D>(deserializer: D) -> std::result::Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    deserialize_const_method(deserializer, mcp_methods::LIST_RESOURCES)
}

// other code...

// Shared helper: read the field as a string and reject any value
// other than the expected constant.
fn deserialize_const_method<'de, D>(
    deserializer: D,
    expected: &'static str,
) -> std::result::Result<String, D::Error>
where
    D: Deserializer<'de>,
{
    let value = String::deserialize(deserializer)?;
    if value == expected {
        Ok(value)
    } else {
        Err(serde::de::Error::custom(format!(
            "Expected method '{}', but got '{}'",
            expected, value
        )))
    }
}

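With that change, code like the following works as intended, because a request can only match the variant whose method constant it actually carries: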

fn handle_request(request: &mut Request) -> anyhow::Result<()> {
    let call_req: ClientRequest = serde_json::from_reader(request.as_reader())?;
    let _res: ServerResult = match call_req {
        ClientRequest::InitializeRequest(InitializeRequest { method, params }) => {
            // handle the initialize request and build a ServerResult
            todo!()
        }
        _ => {
            anyhow::bail!("unsupported request")
        }
    };
    Ok(())
}

If this approach is ok, I would be happy to implement it.
