Skip to content

[auth]: revision of support oauth client_secret_basic / none / custom methods #723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 41 additions & 20 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
auth,
type OAuthClientProvider,
} from "./auth.js";
import { OAuthMetadata } from 'src/shared/auth.js';

// Mock fetch globally
const mockFetch = jest.fn();
Expand Down Expand Up @@ -232,7 +233,7 @@ describe("OAuth Authorization", () => {
ok: false,
status: 404,
});

// Second call (root fallback) succeeds
mockFetch.mockResolvedValueOnce({
ok: true,
Expand All @@ -242,17 +243,17 @@ describe("OAuth Authorization", () => {

const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name");
expect(metadata).toEqual(validMetadata);

const calls = mockFetch.mock.calls;
expect(calls.length).toBe(2);

// First call should be path-aware
const [firstUrl, firstOptions] = calls[0];
expect(firstUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server/path/name");
expect(firstOptions.headers).toEqual({
"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION
});

// Second call should be root fallback
const [secondUrl, secondOptions] = calls[1];
expect(secondUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server");
Expand All @@ -267,7 +268,7 @@ describe("OAuth Authorization", () => {
ok: false,
status: 404,
});

// Second call (root fallback) also returns 404
mockFetch.mockResolvedValueOnce({
ok: false,
Expand All @@ -276,7 +277,7 @@ describe("OAuth Authorization", () => {

const metadata = await discoverOAuthMetadata("https://auth.example.com/path/name");
expect(metadata).toBeUndefined();

const calls = mockFetch.mock.calls;
expect(calls.length).toBe(2);
});
Expand All @@ -290,10 +291,10 @@ describe("OAuth Authorization", () => {

const metadata = await discoverOAuthMetadata("https://auth.example.com/");
expect(metadata).toBeUndefined();

const calls = mockFetch.mock.calls;
expect(calls.length).toBe(1); // Should not attempt fallback

const [url] = calls[0];
expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server");
});
Expand All @@ -307,24 +308,24 @@ describe("OAuth Authorization", () => {

const metadata = await discoverOAuthMetadata("https://auth.example.com");
expect(metadata).toBeUndefined();

const calls = mockFetch.mock.calls;
expect(calls.length).toBe(1); // Should not attempt fallback

const [url] = calls[0];
expect(url.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server");
});

it("falls back when path-aware discovery encounters CORS error", async () => {
// First call (path-aware) fails with TypeError (CORS)
mockFetch.mockImplementationOnce(() => Promise.reject(new TypeError("CORS error")));

// Retry path-aware without headers (simulating CORS retry)
mockFetch.mockResolvedValueOnce({
ok: false,
status: 404,
});

// Second call (root fallback) succeeds
mockFetch.mockResolvedValueOnce({
ok: true,
Expand All @@ -334,10 +335,10 @@ describe("OAuth Authorization", () => {

const metadata = await discoverOAuthMetadata("https://auth.example.com/deep/path");
expect(metadata).toEqual(validMetadata);

const calls = mockFetch.mock.calls;
expect(calls.length).toBe(3);

// Final call should be root fallback
const [lastUrl, lastOptions] = calls[2];
expect(lastUrl.toString()).toBe("https://auth.example.com/.well-known/oauth-authorization-server");
Expand Down Expand Up @@ -588,6 +589,13 @@ describe("OAuth Authorization", () => {
refresh_token: "refresh123",
};

const validMetadata = {
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/authorize",
token_endpoint: "https://auth.example.com/token",
response_types_supported: ["code"]
};

const validClientInfo = {
client_id: "client123",
client_secret: "secret123",
Expand Down Expand Up @@ -641,13 +649,15 @@ describe("OAuth Authorization", () => {
});

const tokens = await exchangeAuthorization("https://auth.example.com", {
metadata: validMetadata,
clientInformation: validClientInfo,
authorizationCode: "code123",
codeVerifier: "verifier123",
redirectUri: "http://localhost:3000/callback",
addClientAuthentication: (url: URL, headers: Headers, params: URLSearchParams) => {
addClientAuthentication: (headers: Headers, params: URLSearchParams, url: string | URL, metadata: OAuthMetadata) => {
headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret));
params.set("example_url", url.toString());
params.set("example_url", typeof url === 'string' ? url : url.toString());
params.set("example_metadata", metadata.authorization_endpoint);
params.set("example_param", "example_value");
},
});
Expand All @@ -671,7 +681,8 @@ describe("OAuth Authorization", () => {
expect(body.get("code_verifier")).toBe("verifier123");
expect(body.get("client_id")).toBeNull();
expect(body.get("redirect_uri")).toBe("http://localhost:3000/callback");
expect(body.get("example_url")).toBe("https://auth.example.com/token");
expect(body.get("example_url")).toBe("https://auth.example.com");
expect(body.get("example_metadata")).toBe("https://auth.example.com/authorize");
expect(body.get("example_param")).toBe("example_value");
expect(body.get("client_secret")).toBeNull();
});
Expand Down Expand Up @@ -724,6 +735,13 @@ describe("OAuth Authorization", () => {
refresh_token: "newrefresh123",
};

const validMetadata = {
issuer: "https://auth.example.com",
authorization_endpoint: "https://auth.example.com/authorize",
token_endpoint: "https://auth.example.com/token",
response_types_supported: ["code"]
};

const validClientInfo = {
client_id: "client123",
client_secret: "secret123",
Expand Down Expand Up @@ -773,11 +791,13 @@ describe("OAuth Authorization", () => {
});

const tokens = await refreshAuthorization("https://auth.example.com", {
metadata: validMetadata,
clientInformation: validClientInfo,
refreshToken: "refresh123",
addClientAuthentication: (url: URL, headers: Headers, params: URLSearchParams) => {
addClientAuthentication: (headers: Headers, params: URLSearchParams, url: string | URL, metadata: OAuthMetadata) => {
headers.set("Authorization", "Basic " + btoa(validClientInfo.client_id + ":" + validClientInfo.client_secret));
params.set("example_url", url.toString());
params.set("example_url", typeof url === 'string' ? url : url.toString());
params.set("example_metadata", metadata.authorization_endpoint);
params.set("example_param", "example_value");
},
});
Expand All @@ -799,7 +819,8 @@ describe("OAuth Authorization", () => {
expect(body.get("grant_type")).toBe("refresh_token");
expect(body.get("refresh_token")).toBe("refresh123");
expect(body.get("client_id")).toBeNull();
expect(body.get("example_url")).toBe("https://auth.example.com/token");
expect(body.get("example_url")).toBe("https://auth.example.com");
expect(body.get("example_metadata")).toBe("https://auth.example.com/authorize");
expect(body.get("example_param")).toBe("example_value");
expect(body.get("client_secret")).toBeNull();
});
Expand Down
42 changes: 21 additions & 21 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,25 @@ export interface OAuthClientProvider {
* the authorization result.
*/
codeVerifier(): string | Promise<string>;

/**
* Adds custom client authentication to OAuth token requests.
*
*
* This optional method allows implementations to customize how client credentials
* are included in token exchange and refresh requests. When provided, this method
* is called instead of the default authentication logic, giving full control over
* the authentication mechanism.
*
*
* Common use cases include:
* - Supporting authentication methods beyond the standard OAuth 2.0 methods
* - Adding custom headers for proprietary authentication schemes
* - Implementing client assertion-based authentication (e.g., JWT bearer tokens)
*
*
* @param url - The token endpoint URL being called
* @param headers - The request headers (can be modified to add authentication)
* @param params - The request body parameters (can be modified to add credentials)
*/
addClientAuthentication?(url: URL, headers: Headers, params: URLSearchParams): void | Promise<void>;
addClientAuthentication?(headers: Headers, params: URLSearchParams, url: string | URL, metadata?: OAuthMetadata): void | Promise<void>;

/**
* If defined, overrides the selection and validation of the
Expand All @@ -112,12 +112,12 @@ export class UnauthorizedError extends Error {

/**
* Determines the best client authentication method to use based on server support and client configuration.
*
*
* Priority order (highest to lowest):
* 1. client_secret_basic (if client secret is available)
* 2. client_secret_post (if client secret is available)
* 3. none (for public clients)
*
*
* @param clientInformation - OAuth client information containing credentials
* @param supportedMethods - Authentication methods supported by the authorization server
* @returns The selected authentication method
Expand All @@ -127,7 +127,7 @@ function selectClientAuthMethod(
supportedMethods: string[]
): string {
const hasClientSecret = !!clientInformation.client_secret;

// If server doesn't specify supported methods, use RFC 6749 defaults
if (supportedMethods.length === 0) {
return hasClientSecret ? "client_secret_post" : "none";
Expand All @@ -137,11 +137,11 @@ function selectClientAuthMethod(
if (hasClientSecret && supportedMethods.includes("client_secret_basic")) {
return "client_secret_basic";
}

if (hasClientSecret && supportedMethods.includes("client_secret_post")) {
return "client_secret_post";
}

if (supportedMethods.includes("none")) {
return "none";
}
Expand All @@ -152,12 +152,12 @@ function selectClientAuthMethod(

/**
* Applies client authentication to the request based on the specified method.
*
*
* Implements OAuth 2.1 client authentication methods:
* - client_secret_basic: HTTP Basic authentication (RFC 6749 Section 2.3.1)
* - client_secret_post: Credentials in request body (RFC 6749 Section 2.3.1)
* - none: Public client authentication (RFC 6749 Section 2.1)
*
*
* @param method - The authentication method to use
* @param clientInformation - OAuth client information containing credentials
* @param headers - HTTP headers object to modify
Expand Down Expand Up @@ -197,7 +197,7 @@ function applyBasicAuth(clientId: string, clientSecret: string | undefined, head
if (!clientSecret) {
throw new Error("client_secret_basic authentication requires a client_secret");
}

const credentials = btoa(`${clientId}:${clientSecret}`);
headers.set("Authorization", `Basic ${credentials}`);
}
Expand Down Expand Up @@ -593,11 +593,11 @@ export async function startAuthorization(

/**
* Exchanges an authorization code for an access token with the given server.
*
*
* Supports multiple client authentication methods as specified in OAuth 2.1:
* - Automatically selects the best authentication method based on server support
* - Falls back to appropriate defaults when server metadata is unavailable
*
*
* @param authorizationServerUrl - The authorization server's base URL
* @param options - Configuration object containing client info, auth code, etc.
* @returns Promise resolving to OAuth tokens
Expand Down Expand Up @@ -650,12 +650,12 @@ export async function exchangeAuthorization(
});

if (addClientAuthentication) {
addClientAuthentication(tokenUrl, headers, params);
addClientAuthentication(headers, params, authorizationServerUrl, metadata);
} else {
// Determine and apply client authentication method
const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? [];
const authMethod = selectClientAuthMethod(clientInformation, supportedMethods);

applyClientAuthentication(authMethod, clientInformation, headers, params);
}

Expand All @@ -678,11 +678,11 @@ export async function exchangeAuthorization(

/**
* Exchange a refresh token for an updated access token.
*
*
* Supports multiple client authentication methods as specified in OAuth 2.1:
* - Automatically selects the best authentication method based on server support
* - Preserves the original refresh token if a new one is not returned
*
*
* @param authorizationServerUrl - The authorization server's base URL
* @param options - Configuration object containing client info, refresh token, etc.
* @returns Promise resolving to OAuth tokens (preserves original refresh_token if not replaced)
Expand Down Expand Up @@ -732,12 +732,12 @@ export async function refreshAuthorization(
});

if (addClientAuthentication) {
addClientAuthentication(tokenUrl, headers, params);
addClientAuthentication(headers, params, authorizationServerUrl, metadata);
} else {
// Determine and apply client authentication method
const supportedMethods = metadata?.token_endpoint_auth_methods_supported ?? [];
const authMethod = selectClientAuthMethod(clientInformation, supportedMethods);

applyClientAuthentication(authMethod, clientInformation, headers, params);
}

Expand Down