diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 2d116344..3e3abe68 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -382,6 +382,29 @@ describe("SSEClientTransport", () => { expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); + it("attaches custom header from provider on initial SSE connection", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + const customHeaders = { + "X-Custom-Header": "custom-value", + }; + + transport = new SSEClientTransport(resourceBaseUrl, { + authProvider: mockAuthProvider, + requestInit: { + headers: customHeaders, + }, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); + it("attaches auth header from provider on POST requests", async () => { mockAuthProvider.tokens.mockResolvedValue({ access_token: "test-token", diff --git a/src/client/sse.ts b/src/client/sse.ts index faffecc4..568a5159 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -106,10 +106,8 @@ export class SSEClientTransport implements Transport { return await this._startOrAuth(); } - private async _commonHeaders(): Promise { - const headers = { - ...this._requestInit?.headers, - } as HeadersInit & Record; + private async _commonHeaders(): Promise { + const headers: HeadersInit = {}; if (this._authProvider) { const tokens = await this._authProvider.tokens(); if (tokens) { @@ -120,24 +118,24 @@ export class SSEClientTransport implements Transport { headers["mcp-protocol-version"] = this._protocolVersion; } - return headers; + return new Headers( + { ...headers, ...this._requestInit?.headers } + ); } private _startOrAuth(): Promise { -const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch + const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch return new Promise((resolve, reject) => { this._eventSource = new EventSource( this._url.href, { ...this._eventSourceInit, fetch: async (url, init) => { - const headers = await this._commonHeaders() + const headers = await this._commonHeaders(); + headers.set("Accept", "text/event-stream"); const response = await fetchImpl(url, { ...init, - headers: new Headers({ - ...headers, - Accept: "text/event-stream" - }) + headers, }) if (response.status === 401 && response.headers.has('www-authenticate')) { @@ -238,8 +236,7 @@ const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typ } try { - const commonHeaders = await this._commonHeaders(); - const headers = new Headers(commonHeaders); + const headers = await this._commonHeaders(); headers.set("content-type", "application/json"); const init = { ...this._requestInit,