Skip to content

feat(auth-timeout): make callback timeout configurable #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ To bypass authentication, or to emit custom headers on all requests to your remo
]
```

* To change the timeout for the OAuth callback (by default `30` seconds), add the `--auth-timeout` flag with a value in seconds. This is useful if the authentication process on the server side takes a long time.

```json
"args": [
"mcp-remote",
"https://remote.mcp.server/sse",
"--auth-timeout",
"60"
]
```

### Transport Strategies

MCP Remote supports different transport strategies when connecting to an MCP server. This allows you to control whether it uses Server-Sent Events (SSE) or HTTP transport, and in what order it tries them.
Expand Down
7 changes: 4 additions & 3 deletions src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async function runClient(
host: string,
staticOAuthClientMetadata: StaticOAuthClientMetadata,
staticOAuthClientInfo: StaticOAuthClientInformationFull,
authTimeoutMs: number,
) {
// Set up event emitter for auth flow
const events = new EventEmitter()
Expand All @@ -44,7 +45,7 @@ async function runClient(
const serverUrlHash = getServerUrlHash(serverUrl)

// Create a lazy auth coordinator
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs)

// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
Expand Down Expand Up @@ -159,8 +160,8 @@ async function runClient(

// Parse command-line arguments and run the client
parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx client.ts <https://server-url> [callback-port] [--debug]')
.then(({ serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo }) => {
return runClient(serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo)
.then(({ serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo, authTimeoutMs }) => {
return runClient(serverUrl, callbackPort, headers, transportStrategy, host, staticOAuthClientMetadata, staticOAuthClientInfo, authTimeoutMs)
})
.catch((error) => {
console.error('Fatal error:', error)
Expand Down
11 changes: 9 additions & 2 deletions src/lib/coordination.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ export async function waitForAuthentication(port: number): Promise<boolean> {
* @param events The event emitter to use for signaling
* @returns An AuthCoordinator object with an initializeAuth method
*/
export function createLazyAuthCoordinator(serverUrlHash: string, callbackPort: number, events: EventEmitter): AuthCoordinator {
export function createLazyAuthCoordinator(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
authTimeoutMs: number,
): AuthCoordinator {
let authState: { server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean } | null = null

return {
Expand All @@ -144,7 +149,7 @@ export function createLazyAuthCoordinator(serverUrlHash: string, callbackPort: n
if (DEBUG) debugLog('Initializing auth coordination on-demand', { serverUrlHash, callbackPort })

// Initialize auth using the existing coordinateAuth logic
authState = await coordinateAuth(serverUrlHash, callbackPort, events)
authState = await coordinateAuth(serverUrlHash, callbackPort, events, authTimeoutMs)
if (DEBUG) debugLog('Auth coordination completed', { skipBrowserAuth: authState.skipBrowserAuth })
return authState
},
Expand All @@ -162,6 +167,7 @@ export async function coordinateAuth(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
authTimeoutMs: number,
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }> {
if (DEBUG) debugLog('Coordinating authentication', { serverUrlHash, callbackPort })

Expand Down Expand Up @@ -228,6 +234,7 @@ export async function coordinateAuth(
port: callbackPort,
path: '/oauth/callback',
events,
authTimeoutMs,
})

// Get the actual port the server is running on
Expand Down
2 changes: 2 additions & 0 deletions src/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ export interface OAuthCallbackServerOptions {
path: string
/** Event emitter to signal when auth code is received */
events: EventEmitter
/** Timeout in milliseconds for the auth callback server's long poll */
authTimeoutMs?: number
}

// optional tatic OAuth client information
Expand Down
152 changes: 151 additions & 1 deletion src/lib/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,153 @@
import { describe, it, expect } from 'vitest'
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { parseCommandLineArgs, setupOAuthCallbackServerWithLongPoll } from './utils'
import { EventEmitter } from 'events'
import express from 'express'

// All sanitizeUrl tests have been moved to the strict-url-sanitise package

describe('parseCommandLineArgs', () => {
const mockUsage = 'Usage: test <url>'
const mockExit = vi.spyOn(process, 'exit').mockImplementation(() => {
throw new Error('process.exit called')
})

beforeEach(() => {
vi.clearAllMocks()
})

afterEach(() => {
mockExit.mockReset()
})

describe('--auth-timeout parsing', () => {
it('should use default timeout of 30000ms when no --auth-timeout flag is provided', async () => {
const args = ['https://example.com']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(30000)
})

it('should parse valid timeout in seconds and convert to milliseconds', async () => {
const args = ['https://example.com', '--auth-timeout', '60']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(60000)
})

it('should parse another valid timeout value', async () => {
const args = ['https://example.com', '--auth-timeout', '120']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(120000)
})

it('should use default timeout when invalid timeout value is provided', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})

const args = ['https://example.com', '--auth-timeout', 'invalid']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(30000)
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining('Warning: Ignoring invalid auth timeout value: invalid. Must be a positive number.')
)

consoleSpy.mockRestore()
})

it('should use default timeout when negative timeout value is provided', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})

const args = ['https://example.com', '--auth-timeout', '-30']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(30000)
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining('Warning: Ignoring invalid auth timeout value: -30. Must be a positive number.')
)

consoleSpy.mockRestore()
})

it('should use default timeout when zero timeout value is provided', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})

const args = ['https://example.com', '--auth-timeout', '0']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(30000)
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining('Warning: Ignoring invalid auth timeout value: 0. Must be a positive number.')
)

consoleSpy.mockRestore()
})

it('should use default timeout when --auth-timeout flag has no value', async () => {
const args = ['https://example.com', '--auth-timeout']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(30000)
})

it('should log when using custom timeout', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})

const args = ['https://example.com', '--auth-timeout', '45']
const result = await parseCommandLineArgs(args, mockUsage)

expect(result.authTimeoutMs).toBe(45000)
expect(consoleSpy).toHaveBeenCalledWith(
expect.stringContaining('Using auth callback timeout: 45 seconds')
)

consoleSpy.mockRestore()
})
})
})

describe('setupOAuthCallbackServerWithLongPoll', () => {
let server: any
let events: EventEmitter

beforeEach(() => {
events = new EventEmitter()
})

afterEach(() => {
if (server) {
server.close()
server = null
}
})

it('should use custom timeout when authTimeoutMs is provided', async () => {
const customTimeout = 5000
const result = setupOAuthCallbackServerWithLongPoll({
port: 0, // Use any available port
path: '/oauth/callback',
events,
authTimeoutMs: customTimeout
})

server = result.server

// Test that the server was created
expect(server).toBeDefined()
expect(typeof result.waitForAuthCode).toBe('function')
})

it('should use default timeout when authTimeoutMs is not provided', async () => {
const result = setupOAuthCallbackServerWithLongPoll({
port: 0, // Use any available port
path: '/oauth/callback',
events
})

server = result.server

// Test that the server was created with defaults
expect(server).toBeDefined()
expect(typeof result.waitForAuthCode).toBe('function')
})
})
16 changes: 15 additions & 1 deletion src/lib/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServe
const longPollTimeout = setTimeout(() => {
log('Long poll timeout reached, responding with 202')
res.status(202).send('Authentication in progress')
}, 30000)
}, options.authTimeoutMs || 30000)

// If auth completes while we're waiting, send the response immediately
authCompletedPromise
Expand Down Expand Up @@ -617,6 +617,19 @@ export async function parseCommandLineArgs(args: string[], usage: string) {
log(`Using authorize resource: ${authorizeResource}`)
}

// Parse auth timeout
let authTimeoutMs = 30000 // Default 30 seconds
const authTimeoutIndex = args.indexOf('--auth-timeout')
if (authTimeoutIndex !== -1 && authTimeoutIndex < args.length - 1) {
const timeoutSeconds = parseInt(args[authTimeoutIndex + 1], 10)
if (!isNaN(timeoutSeconds) && timeoutSeconds > 0) {
authTimeoutMs = timeoutSeconds * 1000
log(`Using auth callback timeout: ${timeoutSeconds} seconds`)
} else {
log(`Warning: Ignoring invalid auth timeout value: ${args[authTimeoutIndex + 1]}. Must be a positive number.`)
}
}

if (!serverUrl) {
log(usage)
process.exit(1)
Expand Down Expand Up @@ -691,6 +704,7 @@ export async function parseCommandLineArgs(args: string[], usage: string) {
staticOAuthClientMetadata,
staticOAuthClientInfo,
authorizeResource,
authTimeoutMs,
}
}

Expand Down
5 changes: 4 additions & 1 deletion src/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async function runProxy(
staticOAuthClientMetadata: StaticOAuthClientMetadata,
staticOAuthClientInfo: StaticOAuthClientInformationFull,
authorizeResource: string,
authTimeoutMs: number,
) {
// Set up event emitter for auth flow
const events = new EventEmitter()
Expand All @@ -44,7 +45,7 @@ async function runProxy(
const serverUrlHash = getServerUrlHash(serverUrl)

// Create a lazy auth coordinator
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events)
const authCoordinator = createLazyAuthCoordinator(serverUrlHash, callbackPort, events, authTimeoutMs)

// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
Expand Down Expand Up @@ -155,6 +156,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts <https://se
staticOAuthClientMetadata,
staticOAuthClientInfo,
authorizeResource,
authTimeoutMs,
}) => {
return runProxy(
serverUrl,
Expand All @@ -165,6 +167,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts <https://se
staticOAuthClientMetadata,
staticOAuthClientInfo,
authorizeResource,
authTimeoutMs,
)
},
)
Expand Down