From 4c2dc3f1ee31ca1c691e97f7e68e841284e32e35 Mon Sep 17 00:00:00 2001 From: Peter Karman Date: Mon, 14 Jul 2025 15:21:48 -0700 Subject: [PATCH] feat(auth-timeout): make callback timeout configurable --- README.md | 11 +++ src/client.ts | 7 +- src/lib/coordination.ts | 11 ++- src/lib/types.ts | 2 + src/lib/utils.test.ts | 152 +++++++++++++++++++++++++++++++++++++++- src/lib/utils.ts | 16 ++++- src/proxy.ts | 5 +- 7 files changed, 196 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 8cce8d6..765b0aa 100644 --- a/README.md +++ b/README.md @@ -135,6 +135,17 @@ To bypass authentication, or to emit custom headers on all requests to your remo ] ``` +* To change the timeout for the OAuth callback (by default `30` seconds), add the `--auth-timeout` flag with a value in seconds. This is useful if the authentication process on the server side takes a long time. + +```json + "args": [ + "mcp-remote", + "https://remote.mcp.server/sse", + "--auth-timeout", + "60" + ] +``` + ### Transport Strategies MCP Remote supports different transport strategies when connecting to an MCP server. This allows you to control whether it uses Server-Sent Events (SSE) or HTTP transport, and in what order it tries them. diff --git a/src/client.ts b/src/client.ts index d9a7343..5a20433 100644 --- a/src/client.ts +++ b/src/client.ts @@ -36,6 +36,7 @@ async function runClient( host: string, staticOAuthClientMetadata: StaticOAuthClientMetadata, staticOAuthClientInfo: StaticOAuthClientInformationFull, + authTimeoutMs: number, ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -44,7 +45,7 @@ async function runClient( const serverUrlHash = getServerUrlHash(serverUrl) // Create a lazy auth coordinator - const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events) + const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs) // Create the OAuth client provider const authProvider = new NodeOAuthClientProvider({ @@ -159,8 +160,8 @@ async function runClient( // Parse command-line arguments and run the client parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx client.ts [callback-port] [--debug]') - .then(({ serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo }) => { - return runClient(serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo) + .then(({ serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo, authTimeoutMs }) => { + return runClient(serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo, authTimeoutMs) }) .catch((error) => { console.error('Fatal error:', error) diff --git a/src/lib/coordination.ts b/src/lib/coordination.ts index 8b729f8..21feca3 100644 --- a/src/lib/coordination.ts +++ b/src/lib/coordination.ts @@ -129,7 +129,12 @@ export async function waitForAuthentication(port: number): Promise { * @param events The event emitter to use for signaling * @returns An AuthCoordinator object with an initializeAuth method */ -export function createLazyAuthCoordinator(serverUrlHash: string, callbackPort: number, events: EventEmitter): AuthCoordinator { +export function createLazyAuthCoordinator( + serverUrlHash: string, + callbackPort: number, + events: EventEmitter, + authTimeoutMs: number, +): AuthCoordinator { let authState: { server: Server; waitForAuthCode: () => Promise; skipBrowserAuth: boolean } | null = null return { @@ -144,7 +149,7 @@ export function createLazyAuthCoordinator(serverUrlHash: string, callbackPort: n if (DEBUG) debugLog('Initializing auth coordination on-demand', { serverUrlHash, callbackPort }) // Initialize auth using the existing coordinateAuth logic - authState = await coordinateAuth(serverUrlHash, callbackPort, events) + authState = await coordinateAuth(serverUrlHash, callbackPort, events, authTimeoutMs) if (DEBUG) debugLog('Auth coordination completed', { skipBrowserAuth: authState.skipBrowserAuth }) return authState }, @@ -162,6 +167,7 @@ export async function coordinateAuth( serverUrlHash: string, callbackPort: number, events: EventEmitter, + authTimeoutMs: number, ): Promise<{ server: Server; waitForAuthCode: () => Promise; skipBrowserAuth: boolean }> { if (DEBUG) debugLog('Coordinating authentication', { serverUrlHash, callbackPort }) @@ -228,6 +234,7 @@ export async function coordinateAuth( port: callbackPort, path: '/oauth/callback', events, + authTimeoutMs, }) // Get the actual port the server is running on diff --git a/src/lib/types.ts b/src/lib/types.ts index 3a310c7..db6a8ab 100644 --- a/src/lib/types.ts +++ b/src/lib/types.ts @@ -41,6 +41,8 @@ export interface OAuthCallbackServerOptions { path: string /** Event emitter to signal when auth code is received */ events: EventEmitter + /** Timeout in milliseconds for the auth callback server's long poll */ + authTimeoutMs?: number } // optional tatic OAuth client information diff --git a/src/lib/utils.test.ts b/src/lib/utils.test.ts index 7e819a1..e5a3e5d 100644 --- a/src/lib/utils.test.ts +++ b/src/lib/utils.test.ts @@ -1,3 +1,153 @@ -import { describe, it, expect } from 'vitest' +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' +import { parseCommandLineArgs, setupOAuthCallbackServerWithLongPoll } from './utils' +import { EventEmitter } from 'events' +import express from 'express' // All sanitizeUrl tests have been moved to the strict-url-sanitise package + +describe('parseCommandLineArgs', () => { + const mockUsage = 'Usage: test ' + const mockExit = vi.spyOn(process, 'exit').mockImplementation(() => { + throw new Error('process.exit called') + }) + + beforeEach(() => { + vi.clearAllMocks() + }) + + afterEach(() => { + mockExit.mockReset() + }) + + describe('--auth-timeout parsing', () => { + it('should use default timeout of 30000ms when no --auth-timeout flag is provided', async () => { + const args = ['https://example.com'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(30000) + }) + + it('should parse valid timeout in seconds and convert to milliseconds', async () => { + const args = ['https://example.com', '--auth-timeout', '60'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(60000) + }) + + it('should parse another valid timeout value', async () => { + const args = ['https://example.com', '--auth-timeout', '120'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(120000) + }) + + it('should use default timeout when invalid timeout value is provided', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + const args = ['https://example.com', '--auth-timeout', 'invalid'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(30000) + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining('Warning: Ignoring invalid auth timeout value: invalid. Must be a positive number.') + ) + + consoleSpy.mockRestore() + }) + + it('should use default timeout when negative timeout value is provided', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + const args = ['https://example.com', '--auth-timeout', '-30'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(30000) + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining('Warning: Ignoring invalid auth timeout value: -30. Must be a positive number.') + ) + + consoleSpy.mockRestore() + }) + + it('should use default timeout when zero timeout value is provided', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + const args = ['https://example.com', '--auth-timeout', '0'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(30000) + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining('Warning: Ignoring invalid auth timeout value: 0. Must be a positive number.') + ) + + consoleSpy.mockRestore() + }) + + it('should use default timeout when --auth-timeout flag has no value', async () => { + const args = ['https://example.com', '--auth-timeout'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(30000) + }) + + it('should log when using custom timeout', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + const args = ['https://example.com', '--auth-timeout', '45'] + const result = await parseCommandLineArgs(args, mockUsage) + + expect(result.authTimeoutMs).toBe(45000) + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining('Using auth callback timeout: 45 seconds') + ) + + consoleSpy.mockRestore() + }) + }) +}) + +describe('setupOAuthCallbackServerWithLongPoll', () => { + let server: any + let events: EventEmitter + + beforeEach(() => { + events = new EventEmitter() + }) + + afterEach(() => { + if (server) { + server.close() + server = null + } + }) + + it('should use custom timeout when authTimeoutMs is provided', async () => { + const customTimeout = 5000 + const result = setupOAuthCallbackServerWithLongPoll({ + port: 0, // Use any available port + path: '/oauth/callback', + events, + authTimeoutMs: customTimeout + }) + + server = result.server + + // Test that the server was created + expect(server).toBeDefined() + expect(typeof result.waitForAuthCode).toBe('function') + }) + + it('should use default timeout when authTimeoutMs is not provided', async () => { + const result = setupOAuthCallbackServerWithLongPoll({ + port: 0, // Use any available port + path: '/oauth/callback', + events + }) + + server = result.server + + // Test that the server was created with defaults + expect(server).toBeDefined() + expect(typeof result.waitForAuthCode).toBe('function') + }) +}) diff --git a/src/lib/utils.ts b/src/lib/utils.ts index 5aa4c84..cb8f016 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -391,7 +391,7 @@ export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServe const longPollTimeout = setTimeout(() => { log('Long poll timeout reached, responding with 202') res.status(202).send('Authentication in progress') - }, 30000) + }, options.authTimeoutMs || 30000) // If auth completes while we're waiting, send the response immediately authCompletedPromise @@ -617,6 +617,19 @@ export async function parseCommandLineArgs(args: string[], usage: string) { log(`Using authorize resource: ${authorizeResource}`) } + // Parse auth timeout + let authTimeoutMs = 30000 // Default 30 seconds + const authTimeoutIndex = args.indexOf('--auth-timeout') + if (authTimeoutIndex !== -1 && authTimeoutIndex < args.length - 1) { + const timeoutSeconds = parseInt(args[authTimeoutIndex + 1], 10) + if (!isNaN(timeoutSeconds) && timeoutSeconds > 0) { + authTimeoutMs = timeoutSeconds * 1000 + log(`Using auth callback timeout: ${timeoutSeconds} seconds`) + } else { + log(`Warning: Ignoring invalid auth timeout value: ${args[authTimeoutIndex + 1]}. Must be a positive number.`) + } + } + if (!serverUrl) { log(usage) process.exit(1) @@ -691,6 +704,7 @@ export async function parseCommandLineArgs(args: string[], usage: string) { staticOAuthClientMetadata, staticOAuthClientInfo, authorizeResource, + authTimeoutMs, } } diff --git a/src/proxy.ts b/src/proxy.ts index 6627bf8..271917a 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -36,6 +36,7 @@ async function runProxy( staticOAuthClientMetadata: StaticOAuthClientMetadata, staticOAuthClientInfo: StaticOAuthClientInformationFull, authorizeResource: string, + authTimeoutMs: number, ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -44,7 +45,7 @@ async function runProxy( const serverUrlHash = getServerUrlHash(serverUrl) // Create a lazy auth coordinator - const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events) + const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs) // Create the OAuth client provider const authProvider = new NodeOAuthClientProvider({ @@ -155,6 +156,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts { return runProxy( serverUrl, @@ -165,6 +167,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts