@@ -453,4 +453,264 @@ describe('SSEServerTransport', () => {
453
453
expect . stringContaining ( `data: /messages?sessionId=${ transport . sessionId } ` ) ) ;
454
454
} ) ;
455
455
} ) ;
456
- } ) ;
456
+
457
+ describe ( 'DNS rebinding protection' , ( ) => {
458
+ beforeEach ( ( ) => {
459
+ jest . clearAllMocks ( ) ;
460
+ } ) ;
461
+
462
+ describe ( 'Host header validation' , ( ) => {
463
+ it ( 'should accept requests with allowed host headers' , async ( ) => {
464
+ const mockRes = createMockResponse ( ) ;
465
+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
466
+ allowedHosts : [ 'localhost:3000' , 'example.com' ] ,
467
+ enableDnsRebindingProtection : true ,
468
+ } ) ;
469
+ await transport . start ( ) ;
470
+
471
+ const mockReq = createMockRequest ( {
472
+ headers : {
473
+ host : 'localhost:3000' ,
474
+ 'content-type' : 'application/json' ,
475
+ }
476
+ } ) ;
477
+ const mockHandleRes = createMockResponse ( ) ;
478
+
479
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
480
+
481
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
482
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
483
+ } ) ;
484
+
485
+ it ( 'should reject requests with disallowed host headers' , async ( ) => {
486
+ const mockRes = createMockResponse ( ) ;
487
+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
488
+ allowedHosts : [ 'localhost:3000' ] ,
489
+ enableDnsRebindingProtection : true ,
490
+ } ) ;
491
+ await transport . start ( ) ;
492
+
493
+ const mockReq = createMockRequest ( {
494
+ headers : {
495
+ host : 'evil.com' ,
496
+ 'content-type' : 'application/json' ,
497
+ }
498
+ } ) ;
499
+ const mockHandleRes = createMockResponse ( ) ;
500
+
501
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
502
+
503
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
504
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Invalid Host header: evil.com' ) ;
505
+ } ) ;
506
+
507
+ it ( 'should reject requests without host header when allowedHosts is configured' , async ( ) => {
508
+ const mockRes = createMockResponse ( ) ;
509
+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
510
+ allowedHosts : [ 'localhost:3000' ] ,
511
+ enableDnsRebindingProtection : true ,
512
+ } ) ;
513
+ await transport . start ( ) ;
514
+
515
+ const mockReq = createMockRequest ( {
516
+ headers : {
517
+ 'content-type' : 'application/json' ,
518
+ }
519
+ } ) ;
520
+ const mockHandleRes = createMockResponse ( ) ;
521
+
522
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
523
+
524
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
525
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Invalid Host header: undefined' ) ;
526
+ } ) ;
527
+ } ) ;
528
+
529
+ describe ( 'Origin header validation' , ( ) => {
530
+ it ( 'should accept requests with allowed origin headers' , async ( ) => {
531
+ const mockRes = createMockResponse ( ) ;
532
+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
533
+ allowedOrigins : [ 'http://localhost:3000' , 'https://example.com' ] ,
534
+ enableDnsRebindingProtection : true ,
535
+ } ) ;
536
+ await transport . start ( ) ;
537
+
538
+ const mockReq = createMockRequest ( {
539
+ headers : {
540
+ origin : 'http://localhost:3000' ,
541
+ 'content-type' : 'application/json' ,
542
+ }
543
+ } ) ;
544
+ const mockHandleRes = createMockResponse ( ) ;
545
+
546
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
547
+
548
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
549
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
550
+ } ) ;
551
+
552
+ it ( 'should reject requests with disallowed origin headers' , async ( ) => {
553
+ const mockRes = createMockResponse ( ) ;
554
+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
555
+ allowedOrigins : [ 'http://localhost:3000' ] ,
556
+ enableDnsRebindingProtection : true ,
557
+ } ) ;
558
+ await transport . start ( ) ;
559
+
560
+ const mockReq = createMockRequest ( {
561
+ headers : {
562
+ origin : 'http://evil.com' ,
563
+ 'content-type' : 'application/json' ,
564
+ }
565
+ } ) ;
566
+ const mockHandleRes = createMockResponse ( ) ;
567
+
568
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
569
+
570
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
571
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Invalid Origin header: http://evil.com' ) ;
572
+ } ) ;
573
+ } ) ;
574
+
575
+ describe ( 'Content-Type validation' , ( ) => {
576
+ it ( 'should accept requests with application/json content-type' , async ( ) => {
577
+ const mockRes = createMockResponse ( ) ;
578
+ const transport = new SSEServerTransport ( '/messages' , mockRes ) ;
579
+ await transport . start ( ) ;
580
+
581
+ const mockReq = createMockRequest ( {
582
+ headers : {
583
+ 'content-type' : 'application/json' ,
584
+ }
585
+ } ) ;
586
+ const mockHandleRes = createMockResponse ( ) ;
587
+
588
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
589
+
590
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
591
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
592
+ } ) ;
593
+
594
+ it ( 'should accept requests with application/json with charset' , async ( ) => {
595
+ const mockRes = createMockResponse ( ) ;
596
+ const transport = new SSEServerTransport ( '/messages' , mockRes ) ;
597
+ await transport . start ( ) ;
598
+
599
+ const mockReq = createMockRequest ( {
600
+ headers : {
601
+ 'content-type' : 'application/json; charset=utf-8' ,
602
+ }
603
+ } ) ;
604
+ const mockHandleRes = createMockResponse ( ) ;
605
+
606
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
607
+
608
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
609
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
610
+ } ) ;
611
+
612
+ it ( 'should reject requests with non-application/json content-type when protection is enabled' , async ( ) => {
613
+ const mockRes = createMockResponse ( ) ;
614
+ const transport = new SSEServerTransport ( '/messages' , mockRes ) ;
615
+ await transport . start ( ) ;
616
+
617
+ const mockReq = createMockRequest ( {
618
+ headers : {
619
+ 'content-type' : 'text/plain' ,
620
+ }
621
+ } ) ;
622
+ const mockHandleRes = createMockResponse ( ) ;
623
+
624
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
625
+
626
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
627
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Error: Unsupported content-type: text/plain' ) ;
628
+ } ) ;
629
+ } ) ;
630
+
631
+ describe ( 'enableDnsRebindingProtection option' , ( ) => {
632
+ it ( 'should skip all validations when enableDnsRebindingProtection is false' , async ( ) => {
633
+ const mockRes = createMockResponse ( ) ;
634
+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
635
+ allowedHosts : [ 'localhost:3000' ] ,
636
+ allowedOrigins : [ 'http://localhost:3000' ] ,
637
+ enableDnsRebindingProtection : false ,
638
+ } ) ;
639
+ await transport . start ( ) ;
640
+
641
+ const mockReq = createMockRequest ( {
642
+ headers : {
643
+ host : 'evil.com' ,
644
+ origin : 'http://evil.com' ,
645
+ 'content-type' : 'text/plain' ,
646
+ }
647
+ } ) ;
648
+ const mockHandleRes = createMockResponse ( ) ;
649
+
650
+ await transport . handlePostMessage ( mockReq , mockHandleRes , { jsonrpc : '2.0' , method : 'test' } ) ;
651
+
652
+ // Should pass even with invalid headers because protection is disabled
653
+ expect ( mockHandleRes . writeHead ) . toHaveBeenCalledWith ( 400 ) ;
654
+ // The error should be from content-type parsing, not DNS rebinding protection
655
+ expect ( mockHandleRes . end ) . toHaveBeenCalledWith ( 'Error: Unsupported content-type: text/plain' ) ;
656
+ } ) ;
657
+ } ) ;
658
+
659
+ describe ( 'Combined validations' , ( ) => {
660
+ it ( 'should validate both host and origin when both are configured' , async ( ) => {
661
+ const mockRes = createMockResponse ( ) ;
662
+ const transport = new SSEServerTransport ( '/messages' , mockRes , {
663
+ allowedHosts : [ 'localhost:3000' ] ,
664
+ allowedOrigins : [ 'http://localhost:3000' ] ,
665
+ enableDnsRebindingProtection : true ,
666
+ } ) ;
667
+ await transport . start ( ) ;
668
+
669
+ // Valid host, invalid origin
670
+ const mockReq1 = createMockRequest ( {
671
+ headers : {
672
+ host : 'localhost:3000' ,
673
+ origin : 'http://evil.com' ,
674
+ 'content-type' : 'application/json' ,
675
+ }
676
+ } ) ;
677
+ const mockHandleRes1 = createMockResponse ( ) ;
678
+
679
+ await transport . handlePostMessage ( mockReq1 , mockHandleRes1 , { jsonrpc : '2.0' , method : 'test' } ) ;
680
+
681
+ expect ( mockHandleRes1 . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
682
+ expect ( mockHandleRes1 . end ) . toHaveBeenCalledWith ( 'Invalid Origin header: http://evil.com' ) ;
683
+
684
+ // Invalid host, valid origin
685
+ const mockReq2 = createMockRequest ( {
686
+ headers : {
687
+ host : 'evil.com' ,
688
+ origin : 'http://localhost:3000' ,
689
+ 'content-type' : 'application/json' ,
690
+ }
691
+ } ) ;
692
+ const mockHandleRes2 = createMockResponse ( ) ;
693
+
694
+ await transport . handlePostMessage ( mockReq2 , mockHandleRes2 , { jsonrpc : '2.0' , method : 'test' } ) ;
695
+
696
+ expect ( mockHandleRes2 . writeHead ) . toHaveBeenCalledWith ( 403 ) ;
697
+ expect ( mockHandleRes2 . end ) . toHaveBeenCalledWith ( 'Invalid Host header: evil.com' ) ;
698
+
699
+ // Both valid
700
+ const mockReq3 = createMockRequest ( {
701
+ headers : {
702
+ host : 'localhost:3000' ,
703
+ origin : 'http://localhost:3000' ,
704
+ 'content-type' : 'application/json' ,
705
+ }
706
+ } ) ;
707
+ const mockHandleRes3 = createMockResponse ( ) ;
708
+
709
+ await transport . handlePostMessage ( mockReq3 , mockHandleRes3 , { jsonrpc : '2.0' , method : 'test' } ) ;
710
+
711
+ expect ( mockHandleRes3 . writeHead ) . toHaveBeenCalledWith ( 202 ) ;
712
+ expect ( mockHandleRes3 . end ) . toHaveBeenCalledWith ( 'Accepted' ) ;
713
+ } ) ;
714
+ } ) ;
715
+ } ) ;
716
+ } ) ;
0 commit comments