Skip to content

Feature: Add completions context #369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
assertThat(initResult).isNotNull();

CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "code_review"),
new CompleteRequest.CompleteArgument("language", "py"));
new CompleteRequest.CompleteArgument("language", "py"), null);

CompleteResult result = mcpClient.completeCompletion(request);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) {
Map<String, Object> params = (Map<String, Object>) object;
Map<String, Object> refMap = (Map<String, Object>) params.get("ref");
Map<String, Object> argMap = (Map<String, Object>) params.get("argument");
Map<String, Object> contextMap = (Map<String, Object>) params.get("context");

String refType = (String) refMap.get("type");

Expand All @@ -735,7 +736,13 @@ private McpSchema.CompleteRequest parseCompletionParams(Object object) {
McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName,
argValue);

return new McpSchema.CompleteRequest(ref, argument);
McpSchema.CompleteRequest.CompleteContext context = null;
if (contextMap != null) {
Map<String, String> arguments = (Map<String, String>) contextMap.get("arguments");
context = new McpSchema.CompleteRequest.CompleteContext(arguments);
}

return new McpSchema.CompleteRequest(ref, argument, context);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
*
* @author Christian Tzolov
* @author Luca Chang
* @author Surbhi Bansal
*/
public final class McpSchema {

Expand Down Expand Up @@ -1447,11 +1448,16 @@ public String identifier() {
@JsonIgnoreProperties(ignoreUnknown = true)
public record CompleteRequest(// @formatter:off
@JsonProperty("ref") McpSchema.CompleteReference ref,
@JsonProperty("argument") CompleteArgument argument) implements Request {
@JsonProperty("argument") CompleteArgument argument,
@JsonProperty("context") CompleteContext context) implements Request {

public record CompleteArgument(
@JsonProperty("name") String name,
@JsonProperty("value") String value) {
}

public record CompleteContext(
@JsonProperty("arguments") Map<String, String> arguments) {
}// @formatter:on
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
package io.modelcontextprotocol.server;

import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;

import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleState;
import org.apache.catalina.startup.Tomcat;
import static org.assertj.core.api.Assertions.assertThat;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;

import com.fasterxml.jackson.databind.ObjectMapper;

import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider;
import io.modelcontextprotocol.server.transport.TomcatTestUtil;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CompleteRequest;
import io.modelcontextprotocol.spec.McpSchema.CompleteResult;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
import io.modelcontextprotocol.spec.McpSchema.Prompt;
import io.modelcontextprotocol.spec.McpSchema.PromptArgument;
import io.modelcontextprotocol.spec.McpSchema.ReadResourceResult;
import io.modelcontextprotocol.spec.McpSchema.ResourceReference;
import io.modelcontextprotocol.spec.McpSchema.PromptReference;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
import io.modelcontextprotocol.spec.McpError;

/**
* Tests for completion functionality with context support.
* @author Surbhi Bansal
*/
class McpCompletionTests {

private HttpServletSseServerTransportProvider mcpServerTransportProvider;

private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message";

McpClient.SyncSpec clientBuilder;

private Tomcat tomcat;

@BeforeEach
public void before() {
// Create and con figure the transport provider
mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder()
.objectMapper(new ObjectMapper())
.messageEndpoint(CUSTOM_MESSAGE_ENDPOINT)
.build();

tomcat = TomcatTestUtil.createTomcatServer("", 3400, mcpServerTransportProvider);
try {
tomcat.start();
assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED);
}
catch (Exception e) {
throw new RuntimeException("Failed to start Tomcat", e);
}

this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + 3400).build());
}

@AfterEach
public void after() {
if (mcpServerTransportProvider != null) {
mcpServerTransportProvider.closeGracefully().block();
}
if (tomcat != null) {
try {
tomcat.stop();
tomcat.destroy();
}
catch (LifecycleException e) {
throw new RuntimeException("Failed to stop Tomcat", e);
}
}
}

@Test
void testCompletionHandlerReceivesContext() {
AtomicReference<CompleteRequest> receivedRequest = new AtomicReference<>();
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (exchange, request) -> {
receivedRequest.set(request);
return new CompleteResult(new CompleteResult.CompleteCompletion(List.of("test-completion"), 1, false));
};

ResourceReference resourceRef = new ResourceReference("ref/resource", "test://resource/{param}");

McpSchema.Resource resource = new McpSchema.Resource("test://resource/{param}", "Test Resource",
"A resource for testing", "text/plain", 123L, null);

var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().completions().build())
.resources(new McpServerFeatures.SyncResourceSpecification(resource,
(exchange, req) -> new ReadResourceResult(List.of())))
.completions(new McpServerFeatures.SyncCompletionSpecification(resourceRef, completionHandler))
.build();

try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
.build();) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Test with context
CompleteRequest request = new CompleteRequest(resourceRef,
new CompleteRequest.CompleteArgument("param", "test"),
new CompleteRequest.CompleteContext(Map.of("previous", "value")));

CompleteResult result = mcpClient.completeCompletion(request);

// Verify handler received the context
assertThat(receivedRequest.get().context()).isNotNull();
assertThat(receivedRequest.get().context().arguments()).containsEntry("previous", "value");
assertThat(result.completion().values()).containsExactly("test-completion");
}

mcpServer.close();
}

@Test
void testCompletionBackwardCompatibility() {
AtomicReference<Boolean> contextWasNull = new AtomicReference<>(false);
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (exchange, request) -> {
contextWasNull.set(request.context() == null);
return new CompleteResult(
new CompleteResult.CompleteCompletion(List.of("no-context-completion"), 1, false));
};

McpSchema.Prompt prompt = new Prompt("test-prompt", "this is a test prompt",
List.of(new PromptArgument("arg", "string", false)));

var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().completions().build())
.prompts(new McpServerFeatures.SyncPromptSpecification(prompt,
(mcpSyncServerExchange, getPromptRequest) -> null))
.completions(new McpServerFeatures.SyncCompletionSpecification(
new PromptReference("ref/prompt", "test-prompt"), completionHandler))
.build();

try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
.build();) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Test without context
CompleteRequest request = new CompleteRequest(new PromptReference("ref/prompt", "test-prompt"),
new CompleteRequest.CompleteArgument("arg", "val"), null);

CompleteResult result = mcpClient.completeCompletion(request);

// Verify context was null
assertThat(contextWasNull.get()).isTrue();
assertThat(result.completion().values()).containsExactly("no-context-completion");
}

mcpServer.close();
}

@Test
void testDependentCompletionScenario() {
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (exchange, request) -> {
// Simulate database/table completion scenario
if (request.ref() instanceof ResourceReference resourceRef) {
if ("db://{database}/{table}".equals(resourceRef.uri())) {
if ("database".equals(request.argument().name())) {
// Complete database names
return new CompleteResult(new CompleteResult.CompleteCompletion(
List.of("users_db", "products_db", "analytics_db"), 3, false));
}
else if ("table".equals(request.argument().name())) {
// Complete table names based on selected database
if (request.context() != null && request.context().arguments() != null) {
String db = request.context().arguments().get("database");
if ("users_db".equals(db)) {
return new CompleteResult(new CompleteResult.CompleteCompletion(
List.of("users", "sessions", "permissions"), 3, false));
}
else if ("products_db".equals(db)) {
return new CompleteResult(new CompleteResult.CompleteCompletion(
List.of("products", "categories", "inventory"), 3, false));
}
}
}
}
}
return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false));
};

McpSchema.Resource resource = new McpSchema.Resource("db://{database}/{table}", "Database Table",
"Resource representing a table in a database", "application/json", 456L, null);

var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().completions().build())
.resources(new McpServerFeatures.SyncResourceSpecification(resource,
(exchange, req) -> new ReadResourceResult(List.of())))
.completions(new McpServerFeatures.SyncCompletionSpecification(
new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler))
.build();

try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
.build();) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// First, complete database
CompleteRequest dbRequest = new CompleteRequest(
new ResourceReference("ref/resource", "db://{database}/{table}"),
new CompleteRequest.CompleteArgument("database", ""), null);

CompleteResult dbResult = mcpClient.completeCompletion(dbRequest);
assertThat(dbResult.completion().values()).contains("users_db", "products_db");

// Then complete table with database context
CompleteRequest tableRequest = new CompleteRequest(
new ResourceReference("ref/resource", "db://{database}/{table}"),
new CompleteRequest.CompleteArgument("table", ""),
new CompleteRequest.CompleteContext(Map.of("database", "users_db")));

CompleteResult tableResult = mcpClient.completeCompletion(tableRequest);
assertThat(tableResult.completion().values()).containsExactly("users", "sessions", "permissions");

// Different database gives different tables
CompleteRequest tableRequest2 = new CompleteRequest(
new ResourceReference("ref/resource", "db://{database}/{table}"),
new CompleteRequest.CompleteArgument("table", ""),
new CompleteRequest.CompleteContext(Map.of("database", "products_db")));

CompleteResult tableResult2 = mcpClient.completeCompletion(tableRequest2);
assertThat(tableResult2.completion().values()).containsExactly("products", "categories", "inventory");
}

mcpServer.close();
}

@Test
void testCompletionErrorOnMissingContext() {
BiFunction<McpSyncServerExchange, CompleteRequest, CompleteResult> completionHandler = (exchange, request) -> {
if (request.ref() instanceof ResourceReference resourceRef) {
if ("db://{database}/{table}".equals(resourceRef.uri())) {
if ("table".equals(request.argument().name())) {
// Check if database context is provided
if (request.context() == null || request.context().arguments() == null
|| !request.context().arguments().containsKey("database")) {
throw new McpError("Please select a database first to see available tables");
}
// Normal completion if context is provided
String db = request.context().arguments().get("database");
if ("test_db".equals(db)) {
return new CompleteResult(new CompleteResult.CompleteCompletion(
List.of("users", "orders", "products"), 3, false));
}
}
}
}
return new CompleteResult(new CompleteResult.CompleteCompletion(List.of(), 0, false));
};

McpSchema.Resource resource = new McpSchema.Resource("db://{database}/{table}", "Database Table",
"Resource representing a table in a database", "application/json", 456L, null);

var mcpServer = McpServer.sync(mcpServerTransportProvider)
.capabilities(ServerCapabilities.builder().completions().build())
.resources(new McpServerFeatures.SyncResourceSpecification(resource,
(exchange, req) -> new ReadResourceResult(List.of())))
.completions(new McpServerFeatures.SyncCompletionSpecification(
new ResourceReference("ref/resource", "db://{database}/{table}"), completionHandler))
.build();

try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample" + "client", "0.0.0"))
.build();) {
InitializeResult initResult = mcpClient.initialize();
assertThat(initResult).isNotNull();

// Try to complete table without database context - should raise error
CompleteRequest requestWithoutContext = new CompleteRequest(
new ResourceReference("ref/resource", "db://{database}/{table}"),
new CompleteRequest.CompleteArgument("table", ""), null);

assertThatExceptionOfType(McpError.class)
.isThrownBy(() -> mcpClient.completeCompletion(requestWithoutContext))
.withMessageContaining("Please select a database first");

// Now complete with proper context - should work normally
CompleteRequest requestWithContext = new CompleteRequest(
new ResourceReference("ref/resource", "db://{database}/{table}"),
new CompleteRequest.CompleteArgument("table", ""),
new CompleteRequest.CompleteContext(Map.of("database", "test_db")));

CompleteResult resultWithContext = mcpClient.completeCompletion(requestWithContext);
assertThat(resultWithContext.completion().values()).containsExactly("users", "orders", "products");
}

mcpServer.close();
}

}