diff --git a/src/common/atlas/apiClient.ts b/src/common/atlas/apiClient.ts index 7e920392..56be6077 100644 --- a/src/common/atlas/apiClient.ts +++ b/src/common/atlas/apiClient.ts @@ -1,8 +1,8 @@ -import config from "../../config.js"; import createClient, { Client, FetchOptions, Middleware } from "openapi-fetch"; import { AccessToken, ClientCredentials } from "simple-oauth2"; import { ApiClientError } from "./apiClientError.js"; import { paths, operations } from "./openapi.js"; +import { packageInfo } from "../../packageInfo.js"; const ATLAS_API_VERSION = "2025-03-12"; @@ -67,7 +67,7 @@ export class ApiClient { baseUrl: options?.baseUrl || "https://cloud.mongodb.com/", userAgent: options?.userAgent || - `AtlasMCP/${config.version} (${process.platform}; ${process.arch}; ${process.env.HOSTNAME || "unknown"})`, + `AtlasMCP/${packageInfo.version} (${process.platform}; ${process.arch}; ${process.env.HOSTNAME || "unknown"})`, }; this.client = createClient({ diff --git a/src/config.ts b/src/config.ts index f5f18ca5..e55ca239 100644 --- a/src/config.ts +++ b/src/config.ts @@ -2,12 +2,11 @@ import path from "path"; import os from "os"; import argv from "yargs-parser"; -import packageJson from "../package.json" with { type: "json" }; import { ReadConcernLevel, ReadPreferenceMode, W } from "mongodb"; // If we decide to support non-string config options, we'll need to extend the mechanism for parsing // env variables. -interface UserConfig { +export interface UserConfig { apiBaseUrl?: string; apiClientId?: string; apiClientSecret?: string; @@ -33,19 +32,12 @@ const defaults: UserConfig = { disabledTools: [], }; -const mergedUserConfig = { +export const config = { ...defaults, ...getEnvConfig(), ...getCliConfig(), }; -const config = { - ...mergedUserConfig, - version: packageJson.version, -}; - -export default config; - function getLogPath(): string { const localDataPath = process.platform === "win32" diff --git a/src/index.ts b/src/index.ts index 944ee92a..60e2ba97 100644 --- a/src/index.ts +++ b/src/index.ts @@ -4,20 +4,25 @@ import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js" import logger from "./logger.js"; import { mongoLogId } from "mongodb-log-writer"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; -import config from "./config.js"; +import { config } from "./config.js"; import { Session } from "./session.js"; import { Server } from "./server.js"; +import { packageInfo } from "./packageInfo.js"; try { - const session = new Session(); + const session = new Session({ + apiBaseUrl: config.apiBaseUrl, + apiClientId: config.apiClientId, + apiClientSecret: config.apiClientSecret, + }); const mcpServer = new McpServer({ - name: "MongoDB Atlas", - version: config.version, + name: packageInfo.mcpServerName, + version: packageInfo.version, }); - const server = new Server({ mcpServer, session, + userConfig: config, }); const transport = new StdioServerTransport(); diff --git a/src/logger.ts b/src/logger.ts index 6682566a..425f56b9 100644 --- a/src/logger.ts +++ b/src/logger.ts @@ -1,6 +1,5 @@ import fs from "fs/promises"; import { MongoLogId, MongoLogManager, MongoLogWriter } from "mongodb-log-writer"; -import config from "./config.js"; import redact from "mongodb-redact"; import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { LoggingMessageNotification } from "@modelcontextprotocol/sdk/types.js"; @@ -98,11 +97,11 @@ class ProxyingLogger extends LoggerBase { const logger = new ProxyingLogger(); export default logger; -export async function initializeLogger(server: McpServer): Promise { - await fs.mkdir(config.logPath, { recursive: true }); +export async function initializeLogger(server: McpServer, logPath: string): Promise { + await fs.mkdir(logPath, { recursive: true }); const manager = new MongoLogManager({ - directory: config.logPath, + directory: logPath, retentionDays: 30, onwarn: console.warn, onerror: console.error, diff --git a/src/packageInfo.ts b/src/packageInfo.ts new file mode 100644 index 00000000..dea9214b --- /dev/null +++ b/src/packageInfo.ts @@ -0,0 +1,6 @@ +import packageJson from "../package.json" with { type: "json" }; + +export const packageInfo = { + version: packageJson.version, + mcpServerName: "MongoDB MCP Server", +}; diff --git a/src/server.ts b/src/server.ts index e3e399a0..85a8bc4f 100644 --- a/src/server.ts +++ b/src/server.ts @@ -5,15 +5,23 @@ import { AtlasTools } from "./tools/atlas/tools.js"; import { MongoDbTools } from "./tools/mongodb/tools.js"; import logger, { initializeLogger } from "./logger.js"; import { mongoLogId } from "mongodb-log-writer"; -import config from "./config.js"; +import { UserConfig } from "./config.js"; + +export interface ServerOptions { + session: Session; + userConfig: UserConfig; + mcpServer: McpServer; +} export class Server { public readonly session: Session; private readonly mcpServer: McpServer; + private readonly userConfig: UserConfig; - constructor({ mcpServer, session }: { mcpServer: McpServer; session: Session }) { - this.mcpServer = mcpServer; + constructor({ session, mcpServer, userConfig }: ServerOptions) { this.session = session; + this.mcpServer = mcpServer; + this.userConfig = userConfig; } async connect(transport: Transport) { @@ -22,7 +30,7 @@ export class Server { this.registerTools(); this.registerResources(); - await initializeLogger(this.mcpServer); + await initializeLogger(this.mcpServer, this.userConfig.logPath); await this.mcpServer.connect(transport); @@ -36,12 +44,12 @@ export class Server { private registerTools() { for (const tool of [...AtlasTools, ...MongoDbTools]) { - new tool(this.session).register(this.mcpServer); + new tool(this.session, this.userConfig).register(this.mcpServer); } } private registerResources() { - if (config.connectionString) { + if (this.userConfig.connectionString) { this.mcpServer.resource( "connection-string", "config://connection-string", @@ -52,7 +60,7 @@ export class Server { return { contents: [ { - text: `Preconfigured connection string: ${config.connectionString}`, + text: `Preconfigured connection string: ${this.userConfig.connectionString}`, uri: uri.href, }, ], diff --git a/src/session.ts b/src/session.ts index 7e7cb209..8ef1932d 100644 --- a/src/session.ts +++ b/src/session.ts @@ -1,22 +1,27 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { ApiClient, ApiClientCredentials } from "./common/atlas/apiClient.js"; -import config from "./config.js"; + +export interface SessionOptions { + apiBaseUrl?: string; + apiClientId?: string; + apiClientSecret?: string; +} export class Session { serviceProvider?: NodeDriverServiceProvider; apiClient: ApiClient; - constructor() { + constructor({ apiBaseUrl, apiClientId, apiClientSecret }: SessionOptions = {}) { const credentials: ApiClientCredentials | undefined = - config.apiClientId && config.apiClientSecret + apiClientId && apiClientSecret ? { - clientId: config.apiClientId, - clientSecret: config.apiClientSecret, + clientId: apiClientId, + clientSecret: apiClientSecret, } : undefined; this.apiClient = new ApiClient({ - baseUrl: config.apiBaseUrl, + baseUrl: apiBaseUrl, credentials, }); } diff --git a/src/tools/atlas/atlasTool.ts b/src/tools/atlas/atlasTool.ts index 0c2cc0cb..6ca5282d 100644 --- a/src/tools/atlas/atlasTool.ts +++ b/src/tools/atlas/atlasTool.ts @@ -1,16 +1,10 @@ import { ToolBase, ToolCategory } from "../tool.js"; -import { Session } from "../../session.js"; -import config from "../../config.js"; export abstract class AtlasToolBase extends ToolBase { - constructor(protected readonly session: Session) { - super(session); - } - protected category: ToolCategory = "atlas"; protected verifyAllowed(): boolean { - if (!config.apiClientId || !config.apiClientSecret) { + if (!this.config.apiClientId || !this.config.apiClientSecret) { return false; } return super.verifyAllowed(); diff --git a/src/tools/mongodb/metadata/connect.ts b/src/tools/mongodb/metadata/connect.ts index fad117da..746da9b3 100644 --- a/src/tools/mongodb/metadata/connect.ts +++ b/src/tools/mongodb/metadata/connect.ts @@ -2,7 +2,6 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; -import config from "../../../config.js"; import { MongoError as DriverError } from "mongodb"; export class ConnectTool extends MongoDBToolBase { @@ -35,7 +34,7 @@ export class ConnectTool extends MongoDBToolBase { protected async execute({ options: optionsArr }: ToolArgs): Promise { const options = optionsArr?.[0]; let connectionString: string; - if (!options && !config.connectionString) { + if (!options && !this.config.connectionString) { return { content: [ { type: "text", text: "No connection details provided." }, @@ -46,7 +45,7 @@ export class ConnectTool extends MongoDBToolBase { if (!options) { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - connectionString = config.connectionString!; + connectionString = this.config.connectionString!; } else if ("connectionString" in options) { connectionString = options.connectionString; } else { @@ -72,9 +71,9 @@ export class ConnectTool extends MongoDBToolBase { // Sometimes the model will supply an incorrect connection string. If the user has configured // a different one as environment variable or a cli argument, suggest using that one instead. if ( - config.connectionString && + this.config.connectionString && error instanceof DriverError && - config.connectionString !== connectionString + this.config.connectionString !== connectionString ) { return { content: [ @@ -82,7 +81,7 @@ export class ConnectTool extends MongoDBToolBase { type: "text", text: `Failed to connect to MongoDB at '${connectionString}' due to error: '${error.message}.` + - `Your config lists a different connection string: '${config.connectionString}' - do you want to try connecting to it instead?`, + `Your config lists a different connection string: '${this.config.connectionString}' - do you want to try connecting to it instead?`, }, ], }; diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index b79c6b9f..d818c7ab 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -1,10 +1,8 @@ import { z } from "zod"; import { ToolArgs, ToolBase, ToolCategory } from "../tool.js"; -import { Session } from "../../session.js"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { ErrorCodes, MongoDBError } from "../../errors.js"; -import config from "../../config.js"; export const DbOperationArgs = { database: z.string().describe("Database name"), @@ -12,15 +10,11 @@ export const DbOperationArgs = { }; export abstract class MongoDBToolBase extends ToolBase { - constructor(session: Session) { - super(session); - } - protected category: ToolCategory = "mongodb"; protected async ensureConnected(): Promise { - if (!this.session.serviceProvider && config.connectionString) { - await this.connectToMongoDB(config.connectionString); + if (!this.session.serviceProvider && this.config.connectionString) { + await this.connectToMongoDB(this.config.connectionString); } if (!this.session.serviceProvider) { @@ -58,13 +52,13 @@ export abstract class MongoDBToolBase extends ToolBase { productDocsLink: "https://docs.mongodb.com/todo-mcp", productName: "MongoDB MCP", readConcern: { - level: config.connectOptions.readConcern, + level: this.config.connectOptions.readConcern, }, - readPreference: config.connectOptions.readPreference, + readPreference: this.config.connectOptions.readPreference, writeConcern: { - w: config.connectOptions.writeConcern, + w: this.config.connectOptions.writeConcern, }, - timeoutMS: config.connectOptions.timeoutMS, + timeoutMS: this.config.connectOptions.timeoutMS, }); this.session.serviceProvider = provider; diff --git a/src/tools/tool.ts b/src/tools/tool.ts index 73f3e853..a0d0f688 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -4,7 +4,7 @@ import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { Session } from "../session.js"; import logger from "../logger.js"; import { mongoLogId } from "mongodb-log-writer"; -import config from "../config.js"; +import { UserConfig } from "../config.js"; export type ToolArgs = z.objectOutputType; @@ -24,7 +24,10 @@ export abstract class ToolBase { protected abstract execute(...args: Parameters>): Promise; - protected constructor(protected session: Session) {} + constructor( + protected readonly session: Session, + protected readonly config: UserConfig + ) {} public register(server: McpServer): void { if (!this.verifyAllowed()) { @@ -54,11 +57,11 @@ export abstract class ToolBase { // Checks if a tool is allowed to run based on the config protected verifyAllowed(): boolean { let errorClarification: string | undefined; - if (config.disabledTools.includes(this.category)) { + if (this.config.disabledTools.includes(this.category)) { errorClarification = `its category, \`${this.category}\`,`; - } else if (config.disabledTools.includes(this.operationType)) { + } else if (this.config.disabledTools.includes(this.operationType)) { errorClarification = `its operation type, \`${this.operationType}\`,`; - } else if (config.disabledTools.includes(this.name)) { + } else if (this.config.disabledTools.includes(this.name)) { errorClarification = `it`; } diff --git a/tests/integration/helpers.ts b/tests/integration/helpers.ts index 28ddbb02..4e236b1a 100644 --- a/tests/integration/helpers.ts +++ b/tests/integration/helpers.ts @@ -4,12 +4,12 @@ import { Server } from "../../src/server.js"; import runner, { MongoCluster } from "mongodb-runner"; import path from "path"; import fs from "fs/promises"; -import { Session } from "../../src/session.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import { MongoClient, ObjectId } from "mongodb"; import { toIncludeAllMembers } from "jest-extended"; -import config from "../../src/config.js"; +import { config, UserConfig } from "../../src/config.js"; import { McpError } from "@modelcontextprotocol/sdk/types.js"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { Session } from "../../src/session.js"; interface ParameterInfo { name: string; @@ -29,7 +29,7 @@ export interface IntegrationTest { randomDbName: () => string; } -export function setupIntegrationTest(): IntegrationTest { +export function setupIntegrationTest(userConfig: UserConfig = config): IntegrationTest { let mongoCluster: runner.MongoCluster | undefined; let mongoClient: MongoClient | undefined; @@ -58,12 +58,19 @@ export function setupIntegrationTest(): IntegrationTest { } ); + const session = new Session({ + apiBaseUrl: userConfig.apiBaseUrl, + apiClientId: userConfig.apiClientId, + apiClientSecret: userConfig.apiClientSecret, + }); + mcpServer = new Server({ + session, + userConfig, mcpServer: new McpServer({ name: "test-server", version: "1.2.3", }), - session: new Session(), }); await mcpServer.connect(serverTransport); await mcpClient.connect(clientTransport); @@ -315,14 +322,3 @@ export function validateThrowsForInvalidArguments( } }); } - -export function describeAtlas(name: number | string | Function | jest.FunctionLike, fn: jest.EmptyFunction) { - if (!process.env.MDB_MCP_API_CLIENT_ID?.length || !process.env.MDB_MCP_API_CLIENT_SECRET?.length) { - return describe.skip("atlas", () => { - describe(name, fn); - }); - } - return describe("atlas", () => { - describe(name, fn); - }); -} diff --git a/tests/integration/server.test.ts b/tests/integration/server.test.ts index 8a0dde4d..5130b4b6 100644 --- a/tests/integration/server.test.ts +++ b/tests/integration/server.test.ts @@ -1,35 +1,61 @@ import { setupIntegrationTest } from "./helpers"; +import { config } from "../../src/config.js"; describe("Server integration test", () => { - const integration = setupIntegrationTest(); + describe("without atlas", () => { + const integration = setupIntegrationTest({ + ...config, + apiClientId: undefined, + apiClientSecret: undefined, + }); - describe("list capabilities", () => { - it("should return positive number of tools", async () => { + it("should return positive number of tools and have no atlas tools", async () => { const tools = await integration.mcpClient().listTools(); expect(tools).toBeDefined(); expect(tools.tools.length).toBeGreaterThan(0); + + const atlasTools = tools.tools.filter((tool) => tool.name.startsWith("atlas-")); + expect(atlasTools.length).toBeLessThanOrEqual(0); }); + }); + describe("with atlas", () => { + const integration = setupIntegrationTest({ + ...config, + apiClientId: "test", + apiClientSecret: "test", + }); + + describe("list capabilities", () => { + it("should return positive number of tools and have some atlas tools", async () => { + const tools = await integration.mcpClient().listTools(); + expect(tools).toBeDefined(); + expect(tools.tools.length).toBeGreaterThan(0); - it("should return no resources", async () => { - await expect(() => integration.mcpClient().listResources()).rejects.toMatchObject({ - message: "MCP error -32601: Method not found", + const atlasTools = tools.tools.filter((tool) => tool.name.startsWith("atlas-")); + expect(atlasTools.length).toBeGreaterThan(0); }); - }); - it("should return no prompts", async () => { - await expect(() => integration.mcpClient().listPrompts()).rejects.toMatchObject({ - message: "MCP error -32601: Method not found", + it("should return no resources", async () => { + await expect(() => integration.mcpClient().listResources()).rejects.toMatchObject({ + message: "MCP error -32601: Method not found", + }); }); - }); - it("should return capabilities", async () => { - const capabilities = integration.mcpClient().getServerCapabilities(); - expect(capabilities).toBeDefined(); - expect(capabilities?.completions).toBeUndefined(); - expect(capabilities?.experimental).toBeUndefined(); - expect(capabilities?.tools).toBeDefined(); - expect(capabilities?.logging).toBeDefined(); - expect(capabilities?.prompts).toBeUndefined(); + it("should return no prompts", async () => { + await expect(() => integration.mcpClient().listPrompts()).rejects.toMatchObject({ + message: "MCP error -32601: Method not found", + }); + }); + + it("should return capabilities", async () => { + const capabilities = integration.mcpClient().getServerCapabilities(); + expect(capabilities).toBeDefined(); + expect(capabilities?.completions).toBeUndefined(); + expect(capabilities?.experimental).toBeUndefined(); + expect(capabilities?.tools).toBeDefined(); + expect(capabilities?.logging).toBeDefined(); + expect(capabilities?.prompts).toBeUndefined(); + }); }); }); }); diff --git a/tests/integration/tools/mongodb/metadata/connect.test.ts b/tests/integration/tools/mongodb/metadata/connect.test.ts index d107885d..3f28a66d 100644 --- a/tests/integration/tools/mongodb/metadata/connect.test.ts +++ b/tests/integration/tools/mongodb/metadata/connect.test.ts @@ -1,6 +1,6 @@ import { getResponseContent, setupIntegrationTest, validateToolMetadata } from "../../../helpers.js"; -import config from "../../../../../src/config.js"; +import { config } from "../../../../../src/config.js"; describe("Connect tool", () => { const integration = setupIntegrationTest();