@@ -340,6 +340,21 @@ class HandshakePattern {
340
340
}
341
341
342
342
record MessagePattern (NoiseHandshake .Role sender , Token [] tokens ) {
343
+
344
+ MessagePattern withAddedToken (final Token token , final int insertionIndex ) {
345
+ if (insertionIndex < 0 || insertionIndex >= this .tokens ().length + 1 ) {
346
+ throw new IllegalArgumentException ("Illegal insertion index" );
347
+ }
348
+
349
+ final Token [] modifiedTokens = new Token [this .tokens ().length + 1 ];
350
+ System .arraycopy (this .tokens (), 0 , modifiedTokens , 0 , insertionIndex );
351
+ modifiedTokens [insertionIndex ] = token ;
352
+ System .arraycopy (this .tokens (), insertionIndex , modifiedTokens ,
353
+ insertionIndex + 1 , this .tokens ().length - insertionIndex );
354
+
355
+ return new MessagePattern (this .sender (), modifiedTokens );
356
+ }
357
+
343
358
@ Override
344
359
public String toString () {
345
360
final String prefix = switch (sender ()) {
@@ -375,18 +390,24 @@ enum Token {
375
390
ES ,
376
391
SE ,
377
392
SS ,
378
- PSK ;
393
+ PSK ,
394
+ E1 ,
395
+ EKEM1 ;
379
396
380
397
static Token fromString (final String string ) {
381
- return switch (string ) {
382
- case "e" , "E" -> E ;
383
- case "s" , "S" -> S ;
384
- case "ee" , "EE" -> EE ;
385
- case "es" , "ES" -> ES ;
386
- case "se" , "SE" -> SE ;
387
- case "ss" , "SS" -> SS ;
388
- case "psk" , "PSK" -> PSK ;
389
- default -> throw new IllegalArgumentException ("Unrecognized token: " + string );
398
+ for (final Token token : Token .values ()) {
399
+ if (token .name ().equalsIgnoreCase (string )) {
400
+ return token ;
401
+ }
402
+ }
403
+
404
+ throw new IllegalArgumentException ("Unrecognized token: " + string );
405
+ }
406
+
407
+ boolean isKeyAgreementToken () {
408
+ return switch (this ) {
409
+ case EE , ES , SE , SS -> true ;
410
+ default -> false ;
390
411
};
391
412
}
392
413
}
@@ -482,6 +503,8 @@ HandshakePattern withModifier(final String modifier) {
482
503
modifiedMessagePatterns = getPatternsWithFallbackModifier ();
483
504
} else if (modifier .startsWith ("psk" )) {
484
505
modifiedMessagePatterns = getPatternsWithPskModifier (modifier );
506
+ } else if ("hfs" .equals (modifier )) {
507
+ modifiedMessagePatterns = getPatternsWithHfsModifier ();
485
508
} else {
486
509
throw new IllegalArgumentException ("Unrecognized modifier: " + modifier );
487
510
}
@@ -538,6 +561,74 @@ private MessagePattern[][] getPatternsWithPskModifier(final String modifier) {
538
561
return new MessagePattern [][] { modifiedPreMessagePatterns , modifiedHandshakeMessagePatterns };
539
562
}
540
563
564
+ private MessagePattern [][] getPatternsWithHfsModifier () {
565
+ // Temporarily combine the pre-messages and "normal" messages to make iteration/state management easier
566
+ final MessagePattern [] messagePatterns =
567
+ new MessagePattern [getPreMessagePatterns ().length + getHandshakeMessagePatterns ().length ];
568
+
569
+ System .arraycopy (getPreMessagePatterns (), 0 , messagePatterns , 0 , getPreMessagePatterns ().length );
570
+ System .arraycopy (getHandshakeMessagePatterns (), 0 , messagePatterns ,
571
+ getPreMessagePatterns ().length , getHandshakeMessagePatterns ().length );
572
+
573
+ boolean insertedE1Token = false ;
574
+ boolean insertedEkem1Token = false ;
575
+
576
+ for (int i = 0 ; i < messagePatterns .length ; i ++) {
577
+ if (!insertedE1Token && Arrays .stream (messagePatterns [i ].tokens ()).anyMatch (token -> token == Token .E )) {
578
+ // We haven't inserted an E1 token yet, and this message pattern needs one. Exactly where it should go depends
579
+ // on whether this message pattern also contains a key agreement token, but either way, this pattern will wind
580
+ // up one token longer than it was when it started.
581
+ int insertionIndex = -1 ;
582
+
583
+ for (int t = 0 ; t < messagePatterns [i ].tokens ().length ; t ++) {
584
+ final Token token = messagePatterns [i ].tokens ()[t ];
585
+
586
+ // TODO Prove that E must come before key agreement tokens
587
+ if (token == Token .E || token .isKeyAgreementToken ()) {
588
+ insertionIndex = t + 1 ;
589
+
590
+ if (token .isKeyAgreementToken ()) {
591
+ break ;
592
+ }
593
+ }
594
+ }
595
+
596
+ messagePatterns [i ] = messagePatterns [i ].withAddedToken (Token .E1 , insertionIndex );
597
+ insertedE1Token = true ;
598
+ }
599
+
600
+ if (!insertedEkem1Token && Arrays .stream (messagePatterns [i ].tokens ()).anyMatch (token -> token == Token .EE )) {
601
+ // We haven't inserted an EKEM1 token yet, and this pattern needs one. EKEM1 tokens always go after the first
602
+ // EE token.
603
+ int insertionIndex = -1 ;
604
+
605
+ for (int t = 0 ; t < messagePatterns [i ].tokens ().length ; t ++) {
606
+ if (messagePatterns [i ].tokens ()[t ] == Token .EE ) {
607
+ insertionIndex = t + 1 ;
608
+ break ;
609
+ }
610
+ }
611
+
612
+ messagePatterns [i ] = messagePatterns [i ].withAddedToken (Token .EKEM1 , insertionIndex );
613
+ insertedEkem1Token = true ;
614
+ }
615
+
616
+ if (insertedE1Token && insertedEkem1Token ) {
617
+ // No need to inspect the rest of the message patterns if we've already inserted both of the HFS tokens
618
+ break ;
619
+ }
620
+ }
621
+
622
+ final MessagePattern [] modifiedPreMessagePatterns = new MessagePattern [getPreMessagePatterns ().length ];
623
+ final MessagePattern [] modifiedHandshakeMessagePatterns = new MessagePattern [getHandshakeMessagePatterns ().length ];
624
+
625
+ System .arraycopy (messagePatterns , 0 , modifiedPreMessagePatterns , 0 , getPreMessagePatterns ().length );
626
+ System .arraycopy (messagePatterns , getPreMessagePatterns ().length ,
627
+ modifiedHandshakeMessagePatterns , 0 , getHandshakeMessagePatterns ().length );
628
+
629
+ return new MessagePattern [][] { modifiedPreMessagePatterns , modifiedHandshakeMessagePatterns };
630
+ }
631
+
541
632
private String getModifiedName (final String modifier ) {
542
633
final String modifiedName ;
543
634
0 commit comments