Skip to content

Commit 26a98de

Browse files
authored
Merge pull request #565 from ddworken/main
Implement DNS Rebinding Protections per spec
2 parents 1e52f38 + 8bc7374 commit 26a98de

File tree

5 files changed

+701
-4
lines changed

5 files changed

+701
-4
lines changed

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,11 @@ app.post('/mcp', async (req, res) => {
444444
onsessioninitialized: (sessionId) => {
445445
// Store the transport by session ID
446446
transports[sessionId] = transport;
447-
}
447+
},
448+
// DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server
449+
// locally, make sure to set:
450+
// enableDnsRebindingProtection: true,
451+
// allowedHosts: ['127.0.0.1'],
448452
});
449453

450454
// Clean up transport when closed
@@ -596,6 +600,22 @@ This stateless approach is useful for:
596600
- RESTful scenarios where each request is independent
597601
- Horizontally scaled deployments without shared session state
598602

603+
#### DNS Rebinding Protection
604+
605+
The Streamable HTTP transport includes DNS rebinding protection to prevent security vulnerabilities. By default, this protection is **disabled** for backwards compatibility.
606+
607+
**Important**: If you are running this server locally, enable DNS rebinding protection:
608+
609+
```typescript
610+
const transport = new StreamableHTTPServerTransport({
611+
sessionIdGenerator: () => randomUUID(),
612+
enableDnsRebindingProtection: true,
613+
614+
allowedHosts: ['127.0.0.1', ...],
615+
allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com']
616+
});
617+
```
618+
599619
### Testing and Debugging
600620

601621
To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information.

src/server/sse.test.ts

Lines changed: 261 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,4 +453,264 @@ describe('SSEServerTransport', () => {
453453
expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`));
454454
});
455455
});
456-
});
456+
457+
describe('DNS rebinding protection', () => {
458+
beforeEach(() => {
459+
jest.clearAllMocks();
460+
});
461+
462+
describe('Host header validation', () => {
463+
it('should accept requests with allowed host headers', async () => {
464+
const mockRes = createMockResponse();
465+
const transport = new SSEServerTransport('/messages', mockRes, {
466+
allowedHosts: ['localhost:3000', 'example.com'],
467+
enableDnsRebindingProtection: true,
468+
});
469+
await transport.start();
470+
471+
const mockReq = createMockRequest({
472+
headers: {
473+
host: 'localhost:3000',
474+
'content-type': 'application/json',
475+
}
476+
});
477+
const mockHandleRes = createMockResponse();
478+
479+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
480+
481+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
482+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
483+
});
484+
485+
it('should reject requests with disallowed host headers', async () => {
486+
const mockRes = createMockResponse();
487+
const transport = new SSEServerTransport('/messages', mockRes, {
488+
allowedHosts: ['localhost:3000'],
489+
enableDnsRebindingProtection: true,
490+
});
491+
await transport.start();
492+
493+
const mockReq = createMockRequest({
494+
headers: {
495+
host: 'evil.com',
496+
'content-type': 'application/json',
497+
}
498+
});
499+
const mockHandleRes = createMockResponse();
500+
501+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
502+
503+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
504+
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com');
505+
});
506+
507+
it('should reject requests without host header when allowedHosts is configured', async () => {
508+
const mockRes = createMockResponse();
509+
const transport = new SSEServerTransport('/messages', mockRes, {
510+
allowedHosts: ['localhost:3000'],
511+
enableDnsRebindingProtection: true,
512+
});
513+
await transport.start();
514+
515+
const mockReq = createMockRequest({
516+
headers: {
517+
'content-type': 'application/json',
518+
}
519+
});
520+
const mockHandleRes = createMockResponse();
521+
522+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
523+
524+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
525+
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined');
526+
});
527+
});
528+
529+
describe('Origin header validation', () => {
530+
it('should accept requests with allowed origin headers', async () => {
531+
const mockRes = createMockResponse();
532+
const transport = new SSEServerTransport('/messages', mockRes, {
533+
allowedOrigins: ['http://localhost:3000', 'https://example.com'],
534+
enableDnsRebindingProtection: true,
535+
});
536+
await transport.start();
537+
538+
const mockReq = createMockRequest({
539+
headers: {
540+
origin: 'http://localhost:3000',
541+
'content-type': 'application/json',
542+
}
543+
});
544+
const mockHandleRes = createMockResponse();
545+
546+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
547+
548+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
549+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
550+
});
551+
552+
it('should reject requests with disallowed origin headers', async () => {
553+
const mockRes = createMockResponse();
554+
const transport = new SSEServerTransport('/messages', mockRes, {
555+
allowedOrigins: ['http://localhost:3000'],
556+
enableDnsRebindingProtection: true,
557+
});
558+
await transport.start();
559+
560+
const mockReq = createMockRequest({
561+
headers: {
562+
origin: 'http://evil.com',
563+
'content-type': 'application/json',
564+
}
565+
});
566+
const mockHandleRes = createMockResponse();
567+
568+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
569+
570+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
571+
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');
572+
});
573+
});
574+
575+
describe('Content-Type validation', () => {
576+
it('should accept requests with application/json content-type', async () => {
577+
const mockRes = createMockResponse();
578+
const transport = new SSEServerTransport('/messages', mockRes);
579+
await transport.start();
580+
581+
const mockReq = createMockRequest({
582+
headers: {
583+
'content-type': 'application/json',
584+
}
585+
});
586+
const mockHandleRes = createMockResponse();
587+
588+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
589+
590+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
591+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
592+
});
593+
594+
it('should accept requests with application/json with charset', async () => {
595+
const mockRes = createMockResponse();
596+
const transport = new SSEServerTransport('/messages', mockRes);
597+
await transport.start();
598+
599+
const mockReq = createMockRequest({
600+
headers: {
601+
'content-type': 'application/json; charset=utf-8',
602+
}
603+
});
604+
const mockHandleRes = createMockResponse();
605+
606+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
607+
608+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
609+
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
610+
});
611+
612+
it('should reject requests with non-application/json content-type when protection is enabled', async () => {
613+
const mockRes = createMockResponse();
614+
const transport = new SSEServerTransport('/messages', mockRes);
615+
await transport.start();
616+
617+
const mockReq = createMockRequest({
618+
headers: {
619+
'content-type': 'text/plain',
620+
}
621+
});
622+
const mockHandleRes = createMockResponse();
623+
624+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
625+
626+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
627+
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
628+
});
629+
});
630+
631+
describe('enableDnsRebindingProtection option', () => {
632+
it('should skip all validations when enableDnsRebindingProtection is false', async () => {
633+
const mockRes = createMockResponse();
634+
const transport = new SSEServerTransport('/messages', mockRes, {
635+
allowedHosts: ['localhost:3000'],
636+
allowedOrigins: ['http://localhost:3000'],
637+
enableDnsRebindingProtection: false,
638+
});
639+
await transport.start();
640+
641+
const mockReq = createMockRequest({
642+
headers: {
643+
host: 'evil.com',
644+
origin: 'http://evil.com',
645+
'content-type': 'text/plain',
646+
}
647+
});
648+
const mockHandleRes = createMockResponse();
649+
650+
await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });
651+
652+
// Should pass even with invalid headers because protection is disabled
653+
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
654+
// The error should be from content-type parsing, not DNS rebinding protection
655+
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
656+
});
657+
});
658+
659+
describe('Combined validations', () => {
660+
it('should validate both host and origin when both are configured', async () => {
661+
const mockRes = createMockResponse();
662+
const transport = new SSEServerTransport('/messages', mockRes, {
663+
allowedHosts: ['localhost:3000'],
664+
allowedOrigins: ['http://localhost:3000'],
665+
enableDnsRebindingProtection: true,
666+
});
667+
await transport.start();
668+
669+
// Valid host, invalid origin
670+
const mockReq1 = createMockRequest({
671+
headers: {
672+
host: 'localhost:3000',
673+
origin: 'http://evil.com',
674+
'content-type': 'application/json',
675+
}
676+
});
677+
const mockHandleRes1 = createMockResponse();
678+
679+
await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' });
680+
681+
expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403);
682+
expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');
683+
684+
// Invalid host, valid origin
685+
const mockReq2 = createMockRequest({
686+
headers: {
687+
host: 'evil.com',
688+
origin: 'http://localhost:3000',
689+
'content-type': 'application/json',
690+
}
691+
});
692+
const mockHandleRes2 = createMockResponse();
693+
694+
await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' });
695+
696+
expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403);
697+
expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com');
698+
699+
// Both valid
700+
const mockReq3 = createMockRequest({
701+
headers: {
702+
host: 'localhost:3000',
703+
origin: 'http://localhost:3000',
704+
'content-type': 'application/json',
705+
}
706+
});
707+
const mockHandleRes3 = createMockResponse();
708+
709+
await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' });
710+
711+
expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202);
712+
expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted');
713+
});
714+
});
715+
});
716+
});

src/server/sse.ts

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,29 @@ import { URL } from 'url';
99

1010
const MAXIMUM_MESSAGE_SIZE = "4mb";
1111

12+
/**
13+
* Configuration options for SSEServerTransport.
14+
*/
15+
export interface SSEServerTransportOptions {
16+
/**
17+
* List of allowed host header values for DNS rebinding protection.
18+
* If not specified, host validation is disabled.
19+
*/
20+
allowedHosts?: string[];
21+
22+
/**
23+
* List of allowed origin header values for DNS rebinding protection.
24+
* If not specified, origin validation is disabled.
25+
*/
26+
allowedOrigins?: string[];
27+
28+
/**
29+
* Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured).
30+
* Default is false for backwards compatibility.
31+
*/
32+
enableDnsRebindingProtection?: boolean;
33+
}
34+
1235
/**
1336
* Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests.
1437
*
@@ -17,6 +40,7 @@ const MAXIMUM_MESSAGE_SIZE = "4mb";
1740
export class SSEServerTransport implements Transport {
1841
private _sseResponse?: ServerResponse;
1942
private _sessionId: string;
43+
private _options: SSEServerTransportOptions;
2044
onclose?: () => void;
2145
onerror?: (error: Error) => void;
2246
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
@@ -27,8 +51,39 @@ export class SSEServerTransport implements Transport {
2751
constructor(
2852
private _endpoint: string,
2953
private res: ServerResponse,
54+
options?: SSEServerTransportOptions,
3055
) {
3156
this._sessionId = randomUUID();
57+
this._options = options || {enableDnsRebindingProtection: false};
58+
}
59+
60+
/**
61+
* Validates request headers for DNS rebinding protection.
62+
* @returns Error message if validation fails, undefined if validation passes.
63+
*/
64+
private validateRequestHeaders(req: IncomingMessage): string | undefined {
65+
// Skip validation if protection is not enabled
66+
if (!this._options.enableDnsRebindingProtection) {
67+
return undefined;
68+
}
69+
70+
// Validate Host header if allowedHosts is configured
71+
if (this._options.allowedHosts && this._options.allowedHosts.length > 0) {
72+
const hostHeader = req.headers.host;
73+
if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) {
74+
return `Invalid Host header: ${hostHeader}`;
75+
}
76+
}
77+
78+
// Validate Origin header if allowedOrigins is configured
79+
if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) {
80+
const originHeader = req.headers.origin;
81+
if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) {
82+
return `Invalid Origin header: ${originHeader}`;
83+
}
84+
}
85+
86+
return undefined;
3287
}
3388

3489
/**
@@ -85,6 +140,15 @@ export class SSEServerTransport implements Transport {
85140
res.writeHead(500).end(message);
86141
throw new Error(message);
87142
}
143+
144+
// Validate request headers for DNS rebinding protection
145+
const validationError = this.validateRequestHeaders(req);
146+
if (validationError) {
147+
res.writeHead(403).end(validationError);
148+
this.onerror?.(new Error(validationError));
149+
return;
150+
}
151+
88152
const authInfo: AuthInfo | undefined = req.auth;
89153
const requestInfo: RequestInfo = { headers: req.headers };
90154

0 commit comments

Comments
 (0)