diff --git a/README.md b/README.md index 9b5d061..70bb011 100644 --- a/README.md +++ b/README.md @@ -14,15 +14,13 @@ https://github.com/user-attachments/assets/c36a72e0-f790-4b3f-8720-294ab7f5f6eb - -This repository contains experimental Model Context Protocol (or MCP) servers for interacting with Algolia APIs. We're sharing it for you to explore and experiment with. -Feel free to use it, fork it, or build on top of it — but just know that it's not officially supported by Algolia and isn't covered under our SLA. +This repository contains experimental Model Context Protocol (or MCP) servers for interacting with Algolia APIs. We're sharing it for you to explore and experiment with. +Feel free to use it, fork it, or build on top of it — but just know that it's not officially supported by Algolia and isn't covered under our SLA. We might update it, break it, or remove it entirely at any time. If you customize or configure things here, there's a chance that work could be lost. Also, using MCP in production could affect your Algolia usage. If you have feedback or ideas (even code!), we'd love to hear it. Just know that we might use it to help improve our products. This project is provided "as is" and "as available," with no guarantees or warranties. To be super clear: MCP isn't considered an "API Client" for SLA purposes. - ## ✨ Quick Start 1. **Download** the latest release from our [GitHub Releases](https://github.com/algolia/mcp-node/releases) @@ -48,11 +46,13 @@ Algolia Node.js MCP enables natural language interactions with your Algolia data Here are some example prompts to get you started: ### Account Management + ``` "What is the email address associated with my Algolia account?" ``` ### Applications + ``` "List all my Algolia apps." "List all the indices are in my 'e-commerce' application and format them into a table sorted by entries." @@ -60,6 +60,7 @@ Here are some example prompts to get you started: ``` ### Search & Indexing + ``` "Search my 'products' index for Nike shoes under $100." "Add the top 10 programming books to my 'library' index using their ISBNs as objectIDs." @@ -67,12 +68,14 @@ Here are some example prompts to get you started: ``` ### Analytics & Insights + ``` "What's the no-results rate for my 'products' index in the DE region? Generate a graph using React and Recharts." "Show me the top 10 searches with no results in the DE region from last week." ``` ### Monitoring & Performance + ``` "Are there any ongoing incidents at Algolia?" "What's the current latency for my 'e-commerce' index?" @@ -101,7 +104,7 @@ Here are some example prompts to get you started: ### Windows & Linux -*Coming soon.* +_Coming soon._ ## ⚙️ Configuration @@ -149,9 +152,9 @@ Usage: algolia-mcp start-server [options] Starts the Algolia MCP server Options: - -o, --allow-tools Comma separated list of tool ids (default: - ["listIndices","getSettings","searchSingleIndex","getTopSearches","getTopHits","getNoResultsRate"]) - -h, --help display help for command + -t, --allow-tools Comma separated list of tool ids (default: getUserInfo,getApplications,...,listIndices) + --credentials Application ID and associated API key to use. Optional: the MCP will authenticate you if unspecified, giving you access to all your applications. + -h, --help display help for command ``` ## 🛠 Development @@ -164,6 +167,7 @@ Options: ### Setup Development Environment 1. Clone the repository: + ```sh git clone https://github.com/algolia/mcp-node cd mcp-node @@ -199,6 +203,7 @@ npm run build -- --outfile dist/algolia-mcp Use the [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) for testing and debugging: 1. Run the debug script: + ```sh cd mcp-node npm run debug @@ -219,6 +224,7 @@ Use the [MCP Inspector](https://modelcontextprotocol.io/docs/tools/inspector) fo ### Logs and Diagnostics Log files are stored in: + - macOS: `~/Library/Logs/algolia-mcp/` - Windows: `%APPDATA%\algolia-mcp\logs\` - Linux: `~/.config/algolia-mcp/logs/` diff --git a/src/DashboardApi.ts b/src/DashboardApi.ts index f950aa6..ff77e31 100644 --- a/src/DashboardApi.ts +++ b/src/DashboardApi.ts @@ -83,7 +83,7 @@ const CreateApiKeyResponse = z.object({ }); type CreateApiKeyResponse = z.infer; -const ACL = [ +export const REQUIRED_ACLS = [ "search", "listIndexes", "analytics", @@ -123,7 +123,8 @@ export class DashboardApi { const apiKeys = this.#options.appState.get("apiKeys"); let apiKey: string | undefined = apiKeys[applicationId]; - const shouldCreateApiKey = !apiKey || !(await this.#hasRightAcl(applicationId, apiKey, ACL)); + const shouldCreateApiKey = + !apiKey || !(await this.#hasRightAcl(applicationId, apiKey, REQUIRED_ACLS)); if (shouldCreateApiKey) { apiKey = await this.#createApiKey(applicationId); @@ -148,7 +149,7 @@ export class DashboardApi { { method: "POST", body: JSON.stringify({ - acl: ACL, + acl: REQUIRED_ACLS, description: "API Key created by and for the Algolia MCP Server", }), }, diff --git a/src/app.ts b/src/app.ts index dc2992f..25541f3 100644 --- a/src/app.ts +++ b/src/app.ts @@ -1,6 +1,6 @@ import { Command } from "commander"; -import { type StartServerOptions } from "./commands/start-server.ts"; import { type ListToolsOptions } from "./commands/list-tools.ts"; +import { ZodError } from "zod"; const program = new Command("algolia-mcp"); @@ -58,13 +58,43 @@ const ALLOW_TOOLS_OPTIONS_TUPLE = [ DEFAULT_ALLOW_TOOLS, ] as const; +function formatErrorForCli(error: unknown): string { + if (error instanceof ZodError) { + return [...error.errors.map((e) => `- ${e.path.join(".") || ""}: ${e.message}`)].join( + "\n", + ); + } + + if (error instanceof Error) { + return error.message; + } + + return "Unknown error"; +} + program .command("start-server", { isDefault: true }) .description("Starts the Algolia MCP server") .option(...ALLOW_TOOLS_OPTIONS_TUPLE) - .action(async (opts: StartServerOptions) => { - const { startServer } = await import("./commands/start-server.ts"); - await startServer(opts); + .option( + "--credentials ", + "Application ID and associated API key to use. Optional: the MCP will authenticate you if unspecified, giving you access to all your applications.", + (val) => { + const [applicationId, apiKey] = val.split(":"); + if (!applicationId || !apiKey) { + throw new Error("Invalid credentials format. Use applicationId:apiKey"); + } + return { applicationId, apiKey }; + }, + ) + .action(async (opts) => { + try { + const { startServer } = await import("./commands/start-server.ts"); + await startServer(opts); + } catch (error) { + console.error(formatErrorForCli(error)); + process.exit(1); + } }); program diff --git a/src/commands/start-server.test.ts b/src/commands/start-server.test.ts new file mode 100644 index 0000000..1d940d4 --- /dev/null +++ b/src/commands/start-server.test.ts @@ -0,0 +1,154 @@ +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; + +import { startServer } from "./start-server.ts"; +import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js"; +import { setupServer } from "msw/node"; +import { http } from "msw"; +import { ZodError } from "zod"; +import type { AppState } from "../appState.ts"; +import { AppStateManager } from "../appState.ts"; +import { REQUIRED_ACLS } from "../DashboardApi.ts"; + +const mswServer = setupServer(); + +beforeAll(() => mswServer.listen()); +afterEach(() => mswServer.resetHandlers()); +afterAll(() => mswServer.close()); + +describe("when specifying credentials flag", () => { + it("should throw if params are missing", async () => { + await expect( + startServer({ + // @ts-expect-error -- I'm testing missing params + credentials: { applicationId: "appId" }, + }), + ).rejects.toThrow(ZodError); + await expect( + startServer({ + // @ts-expect-error -- I'm testing missing params + credentials: { apiKey: "apiKey" }, + }), + ).rejects.toThrow(ZodError); + }); + + it("should not throw if both params are provided", async () => { + vi.spyOn(AppStateManager, "load").mockRejectedValue(new Error("Should not be called")); + const server = await startServer({ credentials: { applicationId: "appId", apiKey: "apiKey" } }); + + expect(AppStateManager.load).not.toHaveBeenCalled(); + + await server.close(); + }); + + it("should allow filtering tools", async () => { + mswServer.use( + http.put("https://appid.algolia.net/1/indexes/indexName/settings", () => + Response.json({ taskId: 123 }), + ), + ); + const client = new Client({ name: "test client", version: "1.0.0" }); + const server = await startServer({ + credentials: { + apiKey: "apiKey", + applicationId: "appId", + }, + allowTools: ["setSettings"], + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const { tools } = await client.listTools(); + + expect(tools).toHaveLength(1); + expect(tools[0].name).toBe("setSettings"); + + const result = await client.callTool({ + name: "setSettings", + arguments: { + indexName: "indexName", + requestBody: { + searchableAttributes: ["title"], + }, + }, + }); + + expect(result).toMatchInlineSnapshot(` + { + "content": [ + { + "text": "{"taskId":123}", + "type": "text", + }, + ], + } + `); + + await server.close(); + }); +}); + +describe("default behavior", () => { + beforeEach(() => { + const mockAppState: AppState = { + accessToken: "accessToken", + refreshToken: "refreshToken", + apiKeys: { + appId: "apiKey", + }, + }; + vi.spyOn(AppStateManager, "load").mockResolvedValue( + // @ts-expect-error -- It's just a partial mock + { + get: vi.fn((k: K) => mockAppState[k]), + update: vi.fn(), + }, + ); + }); + + it("should list dashboard tools", async () => { + const client = new Client({ name: "test client", version: "1.0.0" }); + const server = await startServer({}); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(AppStateManager.load).toHaveBeenCalled(); + + const { tools } = await client.listTools(); + expect(tools).toHaveLength(176); + expect(tools.some((t) => t.name === "getUserInfo")).toBe(true); + }); + + it("should fetch the api key automatically", async () => { + mswServer.use( + http.get("https://appid-dsn.algolia.net/1/keys/apiKey", () => + Response.json({ acl: REQUIRED_ACLS }), + ), + http.get("https://appid.algolia.net/1/indexes/indexName/settings", () => Response.json({})), + ); + const client = new Client({ name: "test client", version: "1.0.0" }); + const server = await startServer({ allowTools: ["getSettings"] }); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await client.callTool({ + name: "getSettings", + arguments: { + applicationId: "appId", + indexName: "indexName", + }, + }); + + expect(result).toMatchInlineSnapshot(` + { + "content": [ + { + "text": "{}", + "type": "text", + }, + ], + } + `); + }); +}); diff --git a/src/commands/start-server.ts b/src/commands/start-server.ts index c04f06a..5b9331d 100644 --- a/src/commands/start-server.ts +++ b/src/commands/start-server.ts @@ -12,6 +12,11 @@ import { registerGetApplications, } from "../tools/registerGetApplications.ts"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; +import type { + ProcessCallbackArguments, + ProcessInputSchema, + RequestMiddleware, +} from "../tools/registerOpenApi.ts"; import { registerOpenApiTools } from "../tools/registerOpenApi.ts"; import { CONFIG } from "../config.ts"; import { @@ -25,7 +30,7 @@ import { SearchSpec, UsageSpec, } from "../openApi.ts"; -import { type CliFilteringOptions, getToolFilter, isToolAllowed } from "../toolFilters.ts"; +import { CliFilteringOptionsSchema, getToolFilter, isToolAllowed } from "../toolFilters.ts"; import { operationId as SetAttributesForFacetingOperationId, registerSetAttributesForFaceting, @@ -36,11 +41,80 @@ import { } from "../tools/registerSetCustomRanking.ts"; import { CustomMcpServer } from "../CustomMcpServer.ts"; +import { z } from "zod"; + +export const StartServerOptionsSchema = CliFilteringOptionsSchema.extend({ + credentials: z + .object({ + applicationId: z.string(), + apiKey: z.string(), + }) + .optional(), +}); + +type StartServerOptions = z.infer; + +function makeRegionRequestMiddleware(dashboardApi: DashboardApi): RequestMiddleware { + return async ({ request, params }) => { + const application = await dashboardApi.getApplication(params.applicationId); + const region = application.data.attributes.log_region === "de" ? "eu" : "us"; + + const url = new URL(request.url); + const regionFromUrl = url.hostname.match(/data\.(.+)\.algolia.com/)?.[0]; + + if (regionFromUrl !== region) { + console.error("Had to adjust region from", regionFromUrl, "to", region); + url.hostname = `data.${region}.algolia.com`; + return new Request(url, request.clone()); + } + + return request; + }; +} + +export async function startServer(options: StartServerOptions): Promise { + const { credentials, ...opts } = StartServerOptionsSchema.parse(options); + const toolFilter = getToolFilter(opts); + + const server = new CustomMcpServer({ + name: "algolia", + version: CONFIG.version, + capabilities: { + resources: {}, + tools: {}, + }, + }); + + const regionHotFixMiddlewares: RequestMiddleware[] = []; + let processCallbackArguments: ProcessCallbackArguments; + const processInputSchema: ProcessInputSchema = (inputSchema) => { + // If we got it from the options, we don't need it from the AI + if (credentials && inputSchema.properties?.applicationId) { + delete inputSchema.properties.applicationId; + + if (Array.isArray(inputSchema.required)) { + inputSchema.required = inputSchema.required.filter((item) => item !== "applicationId"); + } + } -export type StartServerOptions = CliFilteringOptions; + return inputSchema; + }; -export async function startServer(opts: StartServerOptions) { - try { + if (credentials) { + processCallbackArguments = async (params, securityKeys) => { + const result = { ...params }; + + if (securityKeys.has("applicationId")) { + result.applicationId = credentials.applicationId; + } + + if (securityKeys.has("apiKey")) { + result.apiKey = credentials.apiKey; + } + + return result; + }; + } else { const appState = await AppStateManager.load(); if (!appState.get("accessToken")) { @@ -52,21 +126,19 @@ export async function startServer(opts: StartServerOptions) { }); } - const dashboardApi = new DashboardApi({ - baseUrl: CONFIG.dashboardApiBaseUrl, - appState, - }); + const dashboardApi = new DashboardApi({ baseUrl: CONFIG.dashboardApiBaseUrl, appState }); - const server = new CustomMcpServer({ - name: "algolia", - version: CONFIG.version, - capabilities: { - resources: {}, - tools: {}, - }, - }); + processCallbackArguments = async (params, securityKeys) => { + const result = { ...params }; + + if (securityKeys.has("apiKey")) { + result.apiKey = await dashboardApi.getApiKey(params.applicationId); + } - const toolFilter = getToolFilter(opts); + return result; + }; + + regionHotFixMiddlewares.push(makeRegionRequestMiddleware(dashboardApi)); // Dashboard API Tools if (isToolAllowed(GetUserInfoOperationId, toolFilter)) { @@ -77,131 +149,79 @@ export async function startServer(opts: StartServerOptions) { registerGetApplications(server, dashboardApi); } - // Search API Tools - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: SearchSpec, - toolFilter, - }); - - // Analytics API Tools - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: AnalyticsSpec, - toolFilter, - }); - - // Recommend API Tools - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: RecommendSpec, - toolFilter, - }); + // TODO: Make it available when with applicationId+apiKey mode too + if (isToolAllowed(SetAttributesForFacetingOperationId, toolFilter)) { + registerSetAttributesForFaceting(server, dashboardApi); + } - // AB Testing - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: ABTestingSpec, - toolFilter, - }); + if (isToolAllowed(SetCustomRankingOperationId, toolFilter)) { + registerSetCustomRanking(server, dashboardApi); + } + } - // Monitoring API Tools + for (const openApiSpec of [ + SearchSpec, + AnalyticsSpec, + RecommendSpec, + ABTestingSpec, + MonitoringSpec, + CollectionsSpec, + QuerySuggestionsSpec, + ]) { registerOpenApiTools({ server, - dashboardApi, - openApiSpec: MonitoringSpec, + processInputSchema, + processCallbackArguments, + openApiSpec, toolFilter, }); + } - // Usage - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: UsageSpec, - toolFilter, - requestMiddlewares: [ - // The Usage API expects `name` parameter as multiple values - // rather than comma-separated. - async ({ request }) => { - const url = new URL(request.url); - const nameParams = url.searchParams.get("name"); - - if (!nameParams) { - return new Request(url, request.clone()); - } - - const nameValues = nameParams.split(","); - - url.searchParams.delete("name"); - - nameValues.forEach((value) => { - url.searchParams.append("name", value); - }); - + // Usage + registerOpenApiTools({ + server, + processInputSchema, + processCallbackArguments, + openApiSpec: UsageSpec, + toolFilter, + requestMiddlewares: [ + // The Usage API expects `name` parameter as multiple values + // rather than comma-separated. + async ({ request }) => { + const url = new URL(request.url); + const nameParams = url.searchParams.get("name"); + + if (!nameParams) { return new Request(url, request.clone()); - }, - ], - }); + } - // Ingestion API Tools - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: IngestionSpec, - toolFilter, - requestMiddlewares: [ - // Dirty fix for Claud hallucinating regions - async ({ request, params }) => { - const application = await dashboardApi.getApplication(params.applicationId); - const region = application.data.attributes.log_region === "de" ? "eu" : "us"; - - const url = new URL(request.url); - const regionFromUrl = url.hostname.match(/data\.(.+)\.algolia.com/)?.[0]; - - if (regionFromUrl !== region) { - console.error("Had to adjust region from", regionFromUrl, "to", region); - url.hostname = `data.${region}.algolia.com`; - return new Request(url, request.clone()); - } - - return request; - }, - ], - }); + const nameValues = nameParams.split(","); - // Collections API Tools - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: CollectionsSpec, - toolFilter, - }); + url.searchParams.delete("name"); - // Query Suggestions API Tools - registerOpenApiTools({ - server, - dashboardApi, - openApiSpec: QuerySuggestionsSpec, - toolFilter, - }); + nameValues.forEach((value) => { + url.searchParams.append("name", value); + }); - // Custom settings Tools - if (isToolAllowed(SetAttributesForFacetingOperationId, toolFilter)) { - registerSetAttributesForFaceting(server, dashboardApi); - } - - if (isToolAllowed(SetCustomRankingOperationId, toolFilter)) { - registerSetCustomRanking(server, dashboardApi); - } - - const transport = new StdioServerTransport(); - await server.connect(transport); - } catch (err) { - console.error("Error starting server:", err); - process.exit(1); - } + return new Request(url, request.clone()); + }, + ], + }); + + // Ingestion API Tools + registerOpenApiTools({ + server, + processInputSchema, + processCallbackArguments, + openApiSpec: IngestionSpec, + toolFilter, + requestMiddlewares: [ + // Dirty fix for Claud hallucinating regions + ...regionHotFixMiddlewares, + ], + }); + + const transport = new StdioServerTransport(); + await server.connect(transport); + return server; } diff --git a/src/toolFilters.ts b/src/toolFilters.ts index 8581bd0..1e0d52e 100644 --- a/src/toolFilters.ts +++ b/src/toolFilters.ts @@ -1,7 +1,10 @@ -export type CliFilteringOptions = { - allowTools?: string[]; - denyTools?: string[]; -}; +import z from "zod"; + +export const CliFilteringOptionsSchema = z.object({ + allowTools: z.array(z.string()).optional(), + denyTools: z.array(z.string()).optional(), +}); +export type CliFilteringOptions = z.infer; export type ToolFilter = { allowedTools?: Set; diff --git a/src/tools/registerOpenApi.test.ts b/src/tools/registerOpenApi.test.ts index 7e06456..5f2b866 100644 --- a/src/tools/registerOpenApi.test.ts +++ b/src/tools/registerOpenApi.test.ts @@ -1,8 +1,8 @@ import { afterAll, afterEach, beforeAll, describe, expect, it, vi } from "vitest"; import { http, HttpResponse } from "msw"; import type { ToolFilter } from "../toolFilters.ts"; -import type { DashboardApi } from "../DashboardApi.ts"; import { ALL_SPECS, SearchSpec } from "../openApi.ts"; +import type { ProcessCallbackArguments } from "./registerOpenApi.ts"; import { registerOpenApiTools } from "./registerOpenApi.ts"; import { setupServer } from "msw/node"; import { InMemoryTransport } from "@modelcontextprotocol/sdk/inMemory.js"; @@ -21,6 +21,20 @@ beforeAll(() => mswServer.listen()); afterEach(() => mswServer.resetHandlers()); afterAll(() => mswServer.close()); +const processCallbackArguments: ProcessCallbackArguments = async (params, securityKeys) => { + const result = { ...params }; + + if (securityKeys.has("applicationId")) { + result.applicationId = "simba"; + } + + if (securityKeys.has("apiKey")) { + result.apiKey = "dummy_api_key"; + } + + return result; +}; + describe("registerOpenApiTools", () => { it("should generate a working getSettings tool", async () => { mswServer.use( @@ -32,13 +46,9 @@ describe("registerOpenApiTools", () => { const server = new CustomMcpServer({ name: "algolia", version: "1.0.0" }); const client = new Client({ name: "test client", version: "1.0.0" }); - const dashboardApiMock = { - getApiKey: vi.fn().mockResolvedValue("apiKey"), - } as unknown as DashboardApi; - registerOpenApiTools({ server, - dashboardApi: dashboardApiMock, + processCallbackArguments, openApiSpec: SearchSpec, }); @@ -77,13 +87,10 @@ describe("registerOpenApiTools", () => { }; const serverMock = { tool: vi.fn() }; - const dashboardApiMock = { - getApiKey: vi.fn().mockResolvedValue("apiKey"), - }; registerOpenApiTools({ server: serverMock, - dashboardApi: dashboardApiMock as unknown as DashboardApi, + processCallbackArguments, openApiSpec: SearchSpec, toolFilter, }); @@ -94,7 +101,7 @@ describe("registerOpenApiTools", () => { const jsonlResponse = `{ "searchableAttributes": ["title"] } { "searchableAttributes": ["genre"] }`; mswServer.use( - http.get("https://appid.algolia.net/1/indexes/indexName/settings", () => + http.get("https://simba.algolia.net/1/indexes/indexName/settings", () => HttpResponse.text(jsonlResponse), ), ); @@ -123,7 +130,7 @@ describe("registerOpenApiTools", () => { registerOpenApiTools({ server, - dashboardApi: {} as DashboardApi, + processCallbackArguments, openApiSpec: SearchSpec, toolFilter: { allowedTools: new Set(["getSettings", "setSettings", "browse"]), @@ -159,7 +166,7 @@ describe("registerOpenApiTools", () => { for (const openApiSpec of ALL_SPECS) { registerOpenApiTools({ server, - dashboardApi: {} as DashboardApi, + processCallbackArguments, openApiSpec, }); } @@ -175,13 +182,9 @@ describe("registerOpenApiTools", () => { const server = new CustomMcpServer({ name: "algolia", version: "1.0.0" }); const client = new Client({ name: "test client", version: "1.0.0" }); - const dashboardApiMock = { - getApiKey: vi.fn().mockResolvedValue("apiKey"), - } as unknown as DashboardApi; - registerOpenApiTools({ server, - dashboardApi: dashboardApiMock, + processCallbackArguments, openApiSpec: SearchSpec, toolFilter: { allowedTools: new Set(["getSettings"]), @@ -203,13 +206,9 @@ describe("registerOpenApiTools", () => { const server = new CustomMcpServer({ name: "algolia", version: "1.0.0" }); client = new Client({ name: "test client", version: "1.0.0" }); - const dashboardApiMock = { - getApiKey: vi.fn().mockResolvedValue("someKey"), - } as unknown as DashboardApi; - registerOpenApiTools({ server, - dashboardApi: dashboardApiMock, + processCallbackArguments, openApiSpec: SearchSpec, }); diff --git a/src/tools/registerOpenApi.ts b/src/tools/registerOpenApi.ts index 7ea0dfe..e6e5170 100644 --- a/src/tools/registerOpenApi.ts +++ b/src/tools/registerOpenApi.ts @@ -1,4 +1,3 @@ -import { type DashboardApi } from "../DashboardApi.ts"; import { isToolAllowed, type ToolFilter } from "../toolFilters.ts"; import type { Methods, OpenApiSpec, Operation, Parameter, SecurityScheme } from "../openApi.ts"; import { CONFIG } from "../config.ts"; @@ -12,20 +11,29 @@ export type RequestMiddleware = (opts: { params: Record; }) => Promise; +export type ProcessInputSchema = (inputSchema: InputJsonSchema) => InputJsonSchema; +export type ProcessCallbackArguments = ( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + params: Record, + securityKeys: Set, +) => Promise; + type OpenApiToolsOptions = { server: Pick; - dashboardApi: DashboardApi; openApiSpec: OpenApiSpec; toolFilter?: ToolFilter; requestMiddlewares?: Array; + processInputSchema?: ProcessInputSchema; + processCallbackArguments?: ProcessCallbackArguments; }; export async function registerOpenApiTools({ server, - dashboardApi, openApiSpec, toolFilter, requestMiddlewares, + processCallbackArguments, + processInputSchema, }: OpenApiToolsOptions) { for (const [path, methods] of Object.entries(openApiSpec.paths)) { for (const [method, operation] of Object.entries(methods)) { @@ -44,7 +52,7 @@ export async function registerOpenApiTools({ openApiSpec, method: method as Methods, operation, - dashboardApi, + processCallbackArguments, requestMiddlewares, securityKeys, }); @@ -63,6 +71,7 @@ export async function registerOpenApiTools({ addDefinitionsToInputSchema(inputSchema, openApiSpec); inputSchema.required = [...new Set(inputSchema.required)]; + processInputSchema?.(inputSchema); server.tool({ name: operation.operationId, @@ -81,7 +90,7 @@ type ToolCallbackBuildOptions = { method: Methods; operation: Operation; securityKeys: Set; - dashboardApi: DashboardApi; + processCallbackArguments?: ProcessCallbackArguments; requestMiddlewares?: Array; }; @@ -92,10 +101,15 @@ function buildToolCallback({ operation, requestMiddlewares, securityKeys, - dashboardApi, + processCallbackArguments, }: ToolCallbackBuildOptions) { // eslint-disable-next-line @typescript-eslint/no-explicit-any return async (params: Record): Promise => { + // eslint-disable-next-line no-param-reassign + params = processCallbackArguments + ? await processCallbackArguments(params, securityKeys) + : params; + const { requestBody } = params; if (method === "get" && requestBody) { @@ -138,13 +152,7 @@ function buildToolCallback({ throw new Error(`Unsupported security scheme type: ${securityScheme.type}`); } - let value: string; - - if (key === "apiKey") { - value = await dashboardApi.getApiKey(params.applicationId); - } else { - value = params[key]; - } + const value: string = params[key]; if (!value) { throw new Error(`Missing security parameter: ${key}`);