diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 531b545..931d12d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,7 @@ jobs: strategy: matrix: distribution: ['adopt', 'temurin'] - java: [17, 21] + java: [21] fail-fast: false name: JDK ${{ matrix.java }} (${{ matrix.distribution }}) diff --git a/pom.xml b/pom.xml index fd866f7..01e5ca6 100644 --- a/pom.xml +++ b/pom.xml @@ -9,8 +9,8 @@ JGITVER - 17 - 17 + 21 + 21 UTF-8 diff --git a/src/main/java/com/eatthepath/noise/GenerateHandshakeBuilderApp.java b/src/main/java/com/eatthepath/noise/GenerateHandshakeBuilderApp.java index 2fdf8e8..dc5f98e 100644 --- a/src/main/java/com/eatthepath/noise/GenerateHandshakeBuilderApp.java +++ b/src/main/java/com/eatthepath/noise/GenerateHandshakeBuilderApp.java @@ -13,7 +13,8 @@ class GenerateHandshakeBuilderApp { "XK1", "X1K1", "X1X", "XX1", "X1X1", "K1N", "K1K", "KK1", "K1K1", "K1X", "KX1", "K1X1", "I1N", "I1K", "IK1", "I1K1", "I1X", "IX1", "I1X1", "Npsk0", "Kpsk0", "Xpsk1", "NNpsk0", "NNpsk2", "NKpsk0", "NKpsk2", "NXpsk2", "XNpsk3", "XKpsk3", "XXpsk3", "KNpsk0", "KNpsk2", "KKpsk0", "KKpsk2", "KXpsk2", "INpsk1", "INpsk2", "IKpsk1", - "IKpsk2", "IXpsk2" + "IKpsk2", "IXpsk2", "NNhfs", "KNhfs", "NKhfs", "KKhfs", "NXhfs", "KXhfs", "XNhfs", "INhfs", "XKhfs", "IKhfs", + "XXhfs", "IXhfs" }; private static final String INITIALIZER_TEMPLATE = """ diff --git a/src/main/java/com/eatthepath/noise/HandshakePattern.java b/src/main/java/com/eatthepath/noise/HandshakePattern.java index 7c72cbc..c92d363 100644 --- a/src/main/java/com/eatthepath/noise/HandshakePattern.java +++ b/src/main/java/com/eatthepath/noise/HandshakePattern.java @@ -340,6 +340,21 @@ class HandshakePattern { } record MessagePattern(NoiseHandshake.Role sender, Token[] tokens) { + + MessagePattern withAddedToken(final Token token, final int insertionIndex) { + if (insertionIndex < 0 || insertionIndex >= this.tokens().length + 1) { + throw new IllegalArgumentException("Illegal insertion index"); + } + + final Token[] modifiedTokens = new Token[this.tokens().length + 1]; + System.arraycopy(this.tokens(), 0, modifiedTokens, 0, insertionIndex); + modifiedTokens[insertionIndex] = token; + System.arraycopy(this.tokens(), insertionIndex, modifiedTokens, + insertionIndex + 1, this.tokens().length - insertionIndex); + + return new MessagePattern(this.sender(), modifiedTokens); + } + @Override public String toString() { final String prefix = switch (sender()) { @@ -375,18 +390,24 @@ enum Token { ES, SE, SS, - PSK; + PSK, + E1, + EKEM1; static Token fromString(final String string) { - return switch (string) { - case "e", "E" -> E; - case "s", "S" -> S; - case "ee", "EE" -> EE; - case "es", "ES" -> ES; - case "se", "SE" -> SE; - case "ss", "SS" -> SS; - case "psk", "PSK" -> PSK; - default -> throw new IllegalArgumentException("Unrecognized token: " + string); + for (final Token token : Token.values()) { + if (token.name().equalsIgnoreCase(string)) { + return token; + } + } + + throw new IllegalArgumentException("Unrecognized token: " + string); + } + + boolean isKeyAgreementToken() { + return switch (this) { + case EE, ES, SE, SS -> true; + default -> false; }; } } @@ -482,6 +503,8 @@ HandshakePattern withModifier(final String modifier) { modifiedMessagePatterns = getPatternsWithFallbackModifier(); } else if (modifier.startsWith("psk")) { modifiedMessagePatterns = getPatternsWithPskModifier(modifier); + } else if ("hfs".equals(modifier)) { + modifiedMessagePatterns = getPatternsWithHfsModifier(); } else { throw new IllegalArgumentException("Unrecognized modifier: " + modifier); } @@ -538,6 +561,74 @@ private MessagePattern[][] getPatternsWithPskModifier(final String modifier) { return new MessagePattern[][] { modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns }; } + private MessagePattern[][] getPatternsWithHfsModifier() { + // Temporarily combine the pre-messages and "normal" messages to make iteration/state management easier + final MessagePattern[] messagePatterns = + new MessagePattern[getPreMessagePatterns().length + getHandshakeMessagePatterns().length]; + + System.arraycopy(getPreMessagePatterns(), 0, messagePatterns, 0, getPreMessagePatterns().length); + System.arraycopy(getHandshakeMessagePatterns(), 0, messagePatterns, + getPreMessagePatterns().length, getHandshakeMessagePatterns().length); + + boolean insertedE1Token = false; + boolean insertedEkem1Token = false; + + for (int i = 0; i < messagePatterns.length; i++) { + if (!insertedE1Token && Arrays.stream(messagePatterns[i].tokens()).anyMatch(token -> token == Token.E)) { + // We haven't inserted an E1 token yet, and this message pattern needs one. Exactly where it should go depends + // on whether this message pattern also contains a key agreement token, but either way, this pattern will wind + // up one token longer than it was when it started. + int insertionIndex = -1; + + for (int t = 0; t < messagePatterns[i].tokens().length; t++) { + final Token token = messagePatterns[i].tokens()[t]; + + // TODO Prove that E must come before key agreement tokens + if (token == Token.E || token.isKeyAgreementToken()) { + insertionIndex = t + 1; + + if (token.isKeyAgreementToken()) { + break; + } + } + } + + messagePatterns[i] = messagePatterns[i].withAddedToken(Token.E1, insertionIndex); + insertedE1Token = true; + } + + if (!insertedEkem1Token && Arrays.stream(messagePatterns[i].tokens()).anyMatch(token -> token == Token.EE)) { + // We haven't inserted an EKEM1 token yet, and this pattern needs one. EKEM1 tokens always go after the first + // EE token. + int insertionIndex = -1; + + for (int t = 0; t < messagePatterns[i].tokens().length; t++) { + if (messagePatterns[i].tokens()[t] == Token.EE) { + insertionIndex = t + 1; + break; + } + } + + messagePatterns[i] = messagePatterns[i].withAddedToken(Token.EKEM1, insertionIndex); + insertedEkem1Token = true; + } + + if (insertedE1Token && insertedEkem1Token) { + // No need to inspect the rest of the message patterns if we've already inserted both of the HFS tokens + break; + } + } + + final MessagePattern[] modifiedPreMessagePatterns = new MessagePattern[getPreMessagePatterns().length]; + final MessagePattern[] modifiedHandshakeMessagePatterns = new MessagePattern[getHandshakeMessagePatterns().length]; + + System.arraycopy(messagePatterns, 0, modifiedPreMessagePatterns, 0, getPreMessagePatterns().length); + System.arraycopy(messagePatterns, getPreMessagePatterns().length, + modifiedHandshakeMessagePatterns, 0, getHandshakeMessagePatterns().length); + + return new MessagePattern[][] { modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns }; + } + private String getModifiedName(final String modifier) { final String modifiedName; @@ -724,6 +815,10 @@ boolean requiresRemoteStaticPublicKey(final NoiseHandshake.Role role) { .anyMatch(token -> token == Token.S); } + boolean requiresKeyEncapsulationMechanism() { + return getModifiers(getName()).contains("hfs"); + } + @Override public String toString() { final StringBuilder stringBuilder = new StringBuilder(getName() + ":\n"); diff --git a/src/main/java/com/eatthepath/noise/NamedProtocolHandshakeBuilder.java b/src/main/java/com/eatthepath/noise/NamedProtocolHandshakeBuilder.java index 746272e..a5de481 100644 --- a/src/main/java/com/eatthepath/noise/NamedProtocolHandshakeBuilder.java +++ b/src/main/java/com/eatthepath/noise/NamedProtocolHandshakeBuilder.java @@ -205,11 +205,14 @@ public NoiseHandshake build() { keyAgreement, cipher, hash, + null, prologue, localStaticKeyPair, localEphemeralKeyPair, + null, remoteStaticPublicKey, null, + null, preSharedKeys); } } diff --git a/src/main/java/com/eatthepath/noise/NoiseHandshake.java b/src/main/java/com/eatthepath/noise/NoiseHandshake.java index fe56737..a949988 100644 --- a/src/main/java/com/eatthepath/noise/NoiseHandshake.java +++ b/src/main/java/com/eatthepath/noise/NoiseHandshake.java @@ -1,11 +1,10 @@ package com.eatthepath.noise; -import com.eatthepath.noise.component.NoiseCipher; -import com.eatthepath.noise.component.NoiseHash; -import com.eatthepath.noise.component.NoiseKeyAgreement; +import com.eatthepath.noise.component.*; import javax.annotation.Nullable; import javax.crypto.AEADBadTagException; +import javax.crypto.KEM; import javax.crypto.ShortBufferException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -141,6 +140,7 @@ public class NoiseHandshake { private final CipherState cipherState; private final NoiseHash noiseHash; private final NoiseKeyAgreement keyAgreement; + private final NoiseKeyEncapsulationMechanism keyEncapsulationMechanism; private final byte[] chainingKey; private final byte[] hash; @@ -159,6 +159,12 @@ public class NoiseHandshake { @Nullable private PublicKey remoteStaticPublicKey; + @Nullable + private KeyPair localKeyEncapsulationKeyPair; + + @Nullable + private PublicKey remoteKeyEncapsulationPublicKey; + @Nullable private final List preSharedKeys; @@ -189,11 +195,14 @@ public enum Role { final NoiseKeyAgreement keyAgreement, final NoiseCipher noiseCipher, final NoiseHash noiseHash, + @Nullable final NoiseKeyEncapsulationMechanism keyEncapsulationMechanism, @Nullable final byte[] prologue, @Nullable final KeyPair localStaticKeyPair, @Nullable final KeyPair localEphemeralKeyPair, + @Nullable final KeyPair localKeyEncapsulationKeyPair, @Nullable final PublicKey remoteStaticPublicKey, @Nullable final PublicKey remoteEphemeralPublicKey, + @Nullable final PublicKey remoteKeyEncapsulationPublicKey, @Nullable final List preSharedKeys) { this.handshakePattern = handshakePattern; @@ -202,6 +211,7 @@ public enum Role { this.cipherState = new CipherState(noiseCipher); this.noiseHash = noiseHash; this.keyAgreement = keyAgreement; + this.keyEncapsulationMechanism = keyEncapsulationMechanism; if (handshakePattern.requiresLocalStaticKeyPair(role)) { if (localStaticKeyPair == null) { @@ -267,6 +277,10 @@ public enum Role { } } + if (handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism == null) { + throw new IllegalArgumentException(handshakePattern.getName() + " requires a key encapsulation mechanism"); + } + if (localEphemeralKeyPair != null) { try { keyAgreement.checkKeyPair(localEphemeralKeyPair); @@ -279,13 +293,19 @@ public enum Role { this.localStaticKeyPair = localStaticKeyPair; this.localEphemeralKeyPair = localEphemeralKeyPair; + this.localKeyEncapsulationKeyPair = localKeyEncapsulationKeyPair; this.remoteStaticPublicKey = remoteStaticPublicKey; this.remoteEphemeralPublicKey = remoteEphemeralPublicKey; + this.remoteKeyEncapsulationPublicKey = remoteKeyEncapsulationPublicKey; this.preSharedKeys = preSharedKeys; + final String keyAgreementSection = this.keyEncapsulationMechanism == null + ? keyAgreement.getName() + : keyAgreement.getName() + "+" + keyEncapsulationMechanism.getName(); + this.noiseProtocolName = "Noise_" + handshakePattern.getName() + "_" + - keyAgreement.getName() + "_" + + keyAgreementSection + "_" + noiseCipher.getName() + "_" + noiseHash.getName(); @@ -347,7 +367,22 @@ public enum Role { yield staticPublicKey; } - case EE, ES, SE, SS, PSK -> + case E1 -> { + final PublicKey keyEncapsulationPublicKey; + + if (messagePattern.sender() == role) { + keyEncapsulationPublicKey = localKeyEncapsulationKeyPair != null ? localKeyEncapsulationKeyPair.getPublic() : null; + } else { + keyEncapsulationPublicKey = remoteKeyEncapsulationPublicKey; + } + + if (keyEncapsulationPublicKey == null) { + throw new IllegalStateException("Key encapsulation public key for " + messagePattern.sender() + " role must not be null"); + } + + yield keyEncapsulationPublicKey; + } + case EE, ES, SE, SS, PSK, EKEM1 -> throw new IllegalArgumentException("Key-mixing tokens must not appear in pre-messages"); })) .forEach(publicKey -> mixHash(keyAgreement.serializePublicKey(publicKey))); @@ -545,14 +580,11 @@ public int getOutboundMessageLength(final int payloadLength) { throw new IllegalArgumentException("Handshake is not currently expecting to send a message"); } - return getOutboundMessageLength(handshakePattern, currentMessagePattern, keyAgreement.getPublicKeyLength(), payloadLength); + return getOutboundMessageLength(currentMessagePattern, payloadLength); } // Visible for testing - static int getOutboundMessageLength(final HandshakePattern handshakePattern, - final int message, - final int publicKeyLength, - final int payloadLength) { + int getOutboundMessageLength(final int message, final int payloadLength) { if (message < 0 || message >= handshakePattern.getHandshakeMessagePatterns().length) { throw new IndexOutOfBoundsException( @@ -572,6 +604,7 @@ static int getOutboundMessageLength(final HandshakePattern handshakePattern, || token == HandshakePattern.Token.SE || token == HandshakePattern.Token.SS || token == HandshakePattern.Token.PSK + || token == HandshakePattern.Token.EKEM1 || (token == HandshakePattern.Token.E && isPreSharedKeyHandshake)); int messageLength = 0; @@ -579,20 +612,42 @@ static int getOutboundMessageLength(final HandshakePattern handshakePattern, for (final HandshakePattern.Token token : handshakePattern.getHandshakeMessagePatterns()[message].tokens()) { switch (token) { case E -> { - messageLength += publicKeyLength; + messageLength += keyAgreement.getPublicKeyLength(); if (isPreSharedKeyHandshake) { hasKey = true; } } + case S -> { - messageLength += publicKeyLength; + messageLength += keyAgreement.getPublicKeyLength(); if (hasKey) { // If we have a key, then the static key is encrypted and has a 16-byte AEAD tag messageLength += 16; } } + + case E1 -> { + messageLength += keyEncapsulationMechanism.getPublicKeyLength(); + + if (hasKey) { + // If we have a key, then the key encapsulation public key is encrypted and has a 16-byte AEAD tag + messageLength += 16; + } + } + + case EKEM1 -> { + messageLength += keyEncapsulationMechanism.getEncapsulationLength(); + + if (hasKey) { + // If we have a key, then the key encapsulation is encrypted and has a 16-byte AEAD tag + messageLength += 16; + } + + hasKey = true; + } + case EE, ES, SE, SS, PSK -> hasKey = true; } } @@ -624,15 +679,11 @@ public int getPayloadLength(final int handshakeMessageLength) { throw new IllegalStateException("Handshake is not currently expecting to read a message"); } - return getPayloadLength(handshakePattern, currentMessagePattern, keyAgreement.getPublicKeyLength(), handshakeMessageLength); + return getPayloadLength(currentMessagePattern, handshakeMessageLength); } - static int getPayloadLength(final HandshakePattern handshakePattern, - final int message, - final int publicKeyLength, - final int ciphertextLength) { - - final int emptyPayloadMessageLength = getOutboundMessageLength(handshakePattern, message, publicKeyLength, 0); + int getPayloadLength(final int message, final int ciphertextLength) { + final int emptyPayloadMessageLength = getOutboundMessageLength(message, 0); if (ciphertextLength < emptyPayloadMessageLength) { throw new IllegalArgumentException("Ciphertext is shorter than minimum expected message length"); @@ -762,6 +813,44 @@ public int writeMessage(@Nullable final byte[] payload, } } + case E1 -> { + localKeyEncapsulationKeyPair = keyEncapsulationMechanism.generateKeyPair(); + + try { + offset += encryptAndHash( + keyEncapsulationMechanism.serializePublicKey(localKeyEncapsulationKeyPair.getPublic()), + 0, keyEncapsulationMechanism.getPublicKeyLength(), message, offset); + } catch (final ShortBufferException e) { + // This should never happen for buffers we control + throw new AssertionError("Short buffer for key encapsulation public key component", e); + } + } + + case EKEM1 -> { + if (localKeyEncapsulationKeyPair != null) { + throw new IllegalStateException("Local key encapsulation key pair already set"); + } + + if (remoteKeyEncapsulationPublicKey == null) { + throw new IllegalStateException("No remote key encapsulation public key available"); + } + + localKeyEncapsulationKeyPair = keyEncapsulationMechanism.generateKeyPair(); + + final KEM.Encapsulated encapsulated = + keyEncapsulationMechanism.encapsulate(remoteKeyEncapsulationPublicKey); + + try { + offset += encryptAndHash(encapsulated.encapsulation(), + 0, keyEncapsulationMechanism.getEncapsulationLength(), message, offset); + } catch (final ShortBufferException e) { + // This should never happen for buffers we control + throw new AssertionError("Short buffer for key encapsulation component", e); + } + + mixKey(keyEncapsulationMechanism.serializeSharedSecret(encapsulated.key())); + } + case EE, ES, SE, SS, PSK -> handleMixKeyToken(token); } } @@ -896,6 +985,43 @@ public int writeMessage(@Nullable final ByteBuffer payload, } } + case E1 -> { + localKeyEncapsulationKeyPair = keyEncapsulationMechanism.generateKeyPair(); + + try { + bytesWritten += encryptAndHash( + ByteBuffer.wrap(keyEncapsulationMechanism.serializePublicKey(localKeyEncapsulationKeyPair.getPublic())), + message); + } catch (final ShortBufferException e) { + // This should never happen for buffers we control + throw new AssertionError("Short buffer for key encapsulation public key component", e); + } + } + + case EKEM1 -> { + if (localKeyEncapsulationKeyPair != null) { + throw new IllegalStateException("Local key encapsulation key pair already set"); + } + + if (remoteKeyEncapsulationPublicKey == null) { + throw new IllegalStateException("No remote key encapsulation public key available"); + } + + localKeyEncapsulationKeyPair = keyEncapsulationMechanism.generateKeyPair(); + + final KEM.Encapsulated encapsulated = + keyEncapsulationMechanism.encapsulate(remoteKeyEncapsulationPublicKey); + + try { + bytesWritten += encryptAndHash(ByteBuffer.wrap(encapsulated.encapsulation()), message); + } catch (final ShortBufferException e) { + // This should never happen for buffers we control + throw new AssertionError("Short buffer for key encapsulation component", e); + } + + mixKey(keyEncapsulationMechanism.serializeSharedSecret(encapsulated.key())); + } + case EE, ES, SE, SS, PSK -> handleMixKeyToken(token); } } @@ -1020,6 +1146,40 @@ public int readMessage(final byte[] message, offset += staticKeyCiphertextLength; } + case E1 -> { + if (remoteKeyEncapsulationPublicKey != null) { + throw new IllegalStateException("Remote key encapsulation public key already set"); + } + + final int keyEncapsulationPublicKeyCiphertextLength = + keyEncapsulationMechanism.getPublicKeyLength() + (cipherState.hasKey() ? 16 : 0); + + final byte[] keyEncapsulationPublicKeyBytes = new byte[keyEncapsulationMechanism.getPublicKeyLength()]; + + decryptAndHash(message, offset, keyEncapsulationPublicKeyCiphertextLength, keyEncapsulationPublicKeyBytes, 0); + + remoteKeyEncapsulationPublicKey = + keyEncapsulationMechanism.deserializePublicKey(keyEncapsulationPublicKeyBytes); + + offset += keyEncapsulationPublicKeyCiphertextLength; + } + + case EKEM1 -> { + if (localKeyEncapsulationKeyPair == null) { + throw new IllegalStateException("Local key encapsulation key not set"); + } + + final int keyEncapsulationLength = + keyEncapsulationMechanism.getEncapsulationLength() + (cipherState.hasKey() ? 16 : 0); + + final byte[] keyEncapsulationBytes = new byte[keyEncapsulationMechanism.getEncapsulationLength()]; + decryptAndHash(message, offset, keyEncapsulationLength, keyEncapsulationBytes, 0); + + mixKey(keyEncapsulationMechanism.decapsulate(localKeyEncapsulationKeyPair.getPrivate(), keyEncapsulationBytes)); + + offset += keyEncapsulationLength; + } + case EE, ES, SE, SS, PSK -> handleMixKeyToken(token); } } @@ -1132,6 +1292,47 @@ public int readMessage(final ByteBuffer message, remoteStaticPublicKey = keyAgreement.deserializePublicKey(staticKeyBytes); } + case E1 -> { + if (remoteKeyEncapsulationPublicKey != null) { + throw new IllegalStateException("Remote key encapsulation public key already set"); + } + + final int keyEncapsulationPublicKeyCiphertextLength = + keyEncapsulationMechanism.getPublicKeyLength() + (cipherState.hasKey() ? 16 : 0); + + final byte[] keyEncapsulationPublicKeyBytes = new byte[keyEncapsulationMechanism.getPublicKeyLength()]; + + final ByteBuffer keyEncapsulationPublicKeyCiphertextSlice = + message.slice(message.position(), keyEncapsulationPublicKeyCiphertextLength); + + decryptAndHash(keyEncapsulationPublicKeyCiphertextSlice, ByteBuffer.wrap(keyEncapsulationPublicKeyBytes)); + + // Operating on a slice doesn't advance the main buffer's position; do so manually instead + message.position(message.position() + keyEncapsulationPublicKeyCiphertextLength); + + remoteKeyEncapsulationPublicKey = + keyEncapsulationMechanism.deserializePublicKey(keyEncapsulationPublicKeyBytes); + } + + case EKEM1 -> { + if (localKeyEncapsulationKeyPair == null) { + throw new IllegalStateException("Local key encapsulation key not set"); + } + + final int keyEncapsulationLength = + keyEncapsulationMechanism.getEncapsulationLength() + (cipherState.hasKey() ? 16 : 0); + + final byte[] keyEncapsulationBytes = new byte[keyEncapsulationMechanism.getEncapsulationLength()]; + + final ByteBuffer keyEncapsulationCiphertextSlice = message.slice(message.position(), keyEncapsulationLength); + decryptAndHash(keyEncapsulationCiphertextSlice, ByteBuffer.wrap(keyEncapsulationBytes)); + + // Operating on a slice doesn't advance the main buffer's position; do so manually instead + message.position(message.position() + keyEncapsulationLength); + + mixKey(keyEncapsulationMechanism.decapsulate(localKeyEncapsulationKeyPair.getPrivate(), keyEncapsulationBytes)); + } + case EE, ES, SE, SS, PSK -> handleMixKeyToken(token); } } @@ -1331,16 +1532,20 @@ public NoiseHandshake fallbackTo(final String handshakePatternName, @Nullable fi hasFallenBack = true; + // TODO Add support for fallbacks with HFS patterns return new NoiseHandshake(role, fallbackPattern, keyAgreement, cipherState.getCipher(), noiseHash, + null, prologue, fallbackLocalStaticKeyPair, localEphemeralKeyPair, + null, fallbackRemoteStaticPublicKey, fallbackRemoteEphemeralPublicKey, + null, preSharedKeys); } diff --git a/src/main/java/com/eatthepath/noise/NoiseHandshakeBuilder.java b/src/main/java/com/eatthepath/noise/NoiseHandshakeBuilder.java index 4134c32..85fdea3 100644 --- a/src/main/java/com/eatthepath/noise/NoiseHandshakeBuilder.java +++ b/src/main/java/com/eatthepath/noise/NoiseHandshakeBuilder.java @@ -40,6 +40,7 @@ public class NoiseHandshakeBuilder { @Nullable private NoiseCipher cipher; @Nullable private NoiseHash hash; @Nullable private NoiseKeyAgreement keyAgreement; + @Nullable private NoiseKeyEncapsulationMechanism keyEncapsulationMechanism; private NoiseHandshakeBuilder(final NoiseHandshake.Role role, final HandshakePattern handshakePattern, @@ -141,9 +142,19 @@ public NoiseHandshakeBuilder setHash(final String hashName) throws NoSuchAlgorit } /** - * Sets the key agreement algorithm to be used by this handshake. + *

Sets the key agreement algorithm to be used by this handshake. Note that key agreement names may be coupled with + * a key encapsulation mechanism name as described in Section 5 of "KEM-based Hybrid Forward Secrecy for Noise," which + * says, in part:

* - * @param keyAgreementName the name of the Noise key agreement to be used by this handshake + *
When the "hfs" modifier is used, the DH name section must contain a KEM algorithm name directly + * following the DH algorithm name, separated by a plus sign.
+ * + *

Calling this method with a KEM algorithm name in the {@code keyAgreementName} argument is the equivalent of + * calling this method with the bare key agreement algorithm name, then calling + * {@link #setKeyEncapsulationMechanism(String)} with the KEM algorithm name.

+ * + * @param keyAgreementName the name of the Noise key agreement (possibly including a KEM algorithm) to be used by + * this handshake * * @return a reference to this handshake builder * @@ -151,9 +162,44 @@ public NoiseHandshakeBuilder setHash(final String hashName) throws NoSuchAlgorit * @throws IllegalArgumentException if the given name is not recognized as a Noise key agreement algorithm name * * @see NoiseCipher#getInstance(String) + * @see KEM-based Hybrid Forward Secrecy for Noise, Section 5: The "hfs" modifier */ public NoiseHandshakeBuilder setKeyAgreement(final String keyAgreementName) throws NoSuchAlgorithmException { - this.keyAgreement = NoiseKeyAgreement.getInstance(Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null")); + final int separatorIndex = Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null") + .indexOf('+'); + + final String keyAgreementComponent; + + if (separatorIndex == -1) { + keyAgreementComponent = keyAgreementName; + } else { + keyAgreementComponent = keyAgreementName.substring(0, separatorIndex); + final String keyEncapsulationMechanismComponent = keyAgreementName.substring(separatorIndex); + + this.setKeyEncapsulationMechanism(keyEncapsulationMechanismComponent); + } + + this.keyAgreement = NoiseKeyAgreement.getInstance(keyAgreementComponent); + + return this; + } + + /** + * Sets the key encapsulation mechanism to be used by this handshake. + * + * @param keyEncapsulationMechanismName the name of the Noise key encapsulation mechanism to be used by this handshake + * + * @return a reference ot this handshake builder + * + * @throws NoSuchAlgorithmException if the named algorithm is not supported by the current JVM + * @throws IllegalArgumentException if the given name is not recognized as a Noise key encapsulation mechanism + * + * @see NoiseKeyEncapsulationMechanism#getInstance(String) + */ + public NoiseHandshakeBuilder setKeyEncapsulationMechanism(final String keyEncapsulationMechanismName) throws NoSuchAlgorithmException { + this.keyEncapsulationMechanism = NoiseKeyEncapsulationMechanism.getInstance( + Objects.requireNonNull(keyEncapsulationMechanismName, "Key encapsulation mechanism name must not be null")); + return this; } @@ -183,16 +229,25 @@ public NoiseHandshake build() { throw new IllegalArgumentException("Must set a key agreement algorithm before building a Noise handshake"); } + if (handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism == null) { + throw new IllegalArgumentException("Must set a key encapsulation mechanism before building a Noise handshake with an HFS pattern"); + } else if (!handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism != null) { + throw new IllegalArgumentException("Must not specify a key encapsulation mechanism for a non-HFS handshake pattern"); + } + return new NoiseHandshake(role, handshakePattern, keyAgreement, cipher, hash, + keyEncapsulationMechanism, prologue, localStaticKeyPair, null, + null, remoteStaticPublicKey, null, + null, preSharedKey != null ? List.of(preSharedKey) : null); } @@ -3071,4 +3126,580 @@ public static NoiseHandshakeBuilder forIXPsk2Responder(final KeyPair localStatic throw new AssertionError("Statically-generated handshake pattern not found", e); } } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * NNhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forNNHfsInitiator() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("NNhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * NNhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forNNHfsResponder() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("NNhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in a + * KNhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKNHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("KNhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in a + * KNhfs handshake. + * + * + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKNHfsResponder(final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("KNhfs"), + null, + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * NKhfs handshake. + * + * + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forNKHfsInitiator(final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("NKhfs"), + null, + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * NKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forNKHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("NKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in a + * KKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKKHfsInitiator(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("KKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in a + * KKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKKHfsResponder(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("KKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * NXhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forNXHfsInitiator() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("NXhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * NXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forNXHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("NXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in a + * KXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKXHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("KXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in a + * KXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKXHfsResponder(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("KXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * XNhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXNHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("XNhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * XNhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forXNHfsResponder() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("XNhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * INhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forINHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("INhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * INhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forINHfsResponder() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("INhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * XKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXKHfsInitiator(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("XKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * XKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXKHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("XKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * IKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIKHfsInitiator(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("IKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * IKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIKHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("IKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * XXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXXHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("XXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * XXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXXHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("XXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * IXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIXHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("IXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * IXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIXHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("IXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } } diff --git a/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java b/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java index 254352c..169b64e 100644 --- a/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java +++ b/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java @@ -42,57 +42,22 @@ public byte[] generateSecret(final PrivateKey privateKey, final PublicKey public @Override public byte[] serializePublicKey(final PublicKey publicKey) { - // This is a little hacky, but the structure for an X.509 public key defines the order in which its elements appear. - // The first part of the key, which defines the algorithm and its parameters, is always the same for keys of the - // same type, and the last N bytes are the literal key material. - final byte[] serializedPublicKey = new byte[getPublicKeyLength()]; - System.arraycopy(publicKey.getEncoded(), getX509Prefix().length, serializedPublicKey, 0, getPublicKeyLength()); - - return serializedPublicKey; + return XECUtil.serializePublicKey(publicKey, getPublicKeyLength(), getX509Prefix()); } @Override public PublicKey deserializePublicKey(final byte[] publicKeyBytes) { - final int publicKeyLength = getPublicKeyLength(); - - if (publicKeyBytes.length != publicKeyLength) { - throw new IllegalArgumentException("Unexpected serialized public key length"); - } - - final byte[] x509Prefix = getX509Prefix(); - final byte[] x509Bytes = new byte[publicKeyLength + x509Prefix.length]; - System.arraycopy(x509Prefix, 0, x509Bytes, 0, x509Prefix.length); - System.arraycopy(publicKeyBytes, 0, x509Bytes, x509Prefix.length, publicKeyLength); - - try { - return keyFactory.generatePublic(new X509EncodedKeySpec(x509Bytes, keyFactory.getAlgorithm())); - } catch (final InvalidKeySpecException e) { - throw new IllegalArgumentException("Invalid key", e); - } + return XECUtil.deserializePublicKey(publicKeyBytes, getPublicKeyLength(), getX509Prefix(), keyFactory); } @Override public void checkPublicKey(final PublicKey publicKey) throws InvalidKeyException { - checkKey(publicKey); + XECUtil.checkKey(publicKey, keyAgreement.getAlgorithm()); } @Override public void checkKeyPair(final KeyPair keyPair) throws InvalidKeyException { - checkKey(keyPair.getPublic()); - checkKey(keyPair.getPrivate()); - } - - private void checkKey(final Key key) throws InvalidKeyException { - if (key instanceof XECKey xecKey) { - if (xecKey.getParams() instanceof NamedParameterSpec namedParameterSpec) { - if (!keyAgreement.getAlgorithm().equals(namedParameterSpec.getName())) { - throw new InvalidKeyException("Unexpected key algorithm: " + namedParameterSpec.getName()); - } - } else { - throw new InvalidKeyException("Unexpected key parameter type: " + xecKey.getParams().getClass()); - } - } else { - throw new InvalidKeyException("Unexpected key type: " + key.getClass()); - } + XECUtil.checkKey(keyPair.getPublic(), keyAgreement.getAlgorithm()); + XECUtil.checkKey(keyPair.getPrivate(), keyAgreement.getAlgorithm()); } } diff --git a/src/main/java/com/eatthepath/noise/component/DhkemKeyEncapsulationMechanism.java b/src/main/java/com/eatthepath/noise/component/DhkemKeyEncapsulationMechanism.java new file mode 100644 index 0000000..57ff3ec --- /dev/null +++ b/src/main/java/com/eatthepath/noise/component/DhkemKeyEncapsulationMechanism.java @@ -0,0 +1,81 @@ +package com.eatthepath.noise.component; + +import javax.crypto.DecapsulateException; +import javax.crypto.KEM; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.security.*; + +public class DhkemKeyEncapsulationMechanism implements NoiseKeyEncapsulationMechanism { + + private final KEM kem; + private final KeyPairGenerator keyPairGenerator; + private final KeyFactory keyFactory; + + public DhkemKeyEncapsulationMechanism() throws NoSuchAlgorithmException { + this.keyPairGenerator = KeyPairGenerator.getInstance("X25519"); + this.keyFactory = KeyFactory.getInstance("X25519"); + this.kem = KEM.getInstance("DHKEM"); + } + + @Override + public String getName() { + return "DHKEM"; + } + + @Override + public KeyPair generateKeyPair() { + return keyPairGenerator.generateKeyPair(); + } + + @Override + public KEM.Encapsulated encapsulate(final PublicKey publicKey) { + try { + return kem.newEncapsulator(publicKey).encapsulate(); + } catch (final InvalidKeyException e) { + throw new IllegalArgumentException("Invalid public key for encapsulation", e); + } + } + + @Override + public byte[] decapsulate(final PrivateKey privateKey, final byte[] encapsulation) { + try { + return serializeSharedSecret(kem.newDecapsulator(privateKey).decapsulate(encapsulation)); + } catch (final DecapsulateException e) { + throw new IllegalArgumentException("Invalid encapsulation", e); + } catch (final InvalidKeyException e) { + throw new IllegalArgumentException("Invalid private key for decapsulation", e); + } + } + + @Override + public int getPublicKeyLength() { + return 32; + } + + @Override + public int getEncapsulationLength() { + return 32; + } + + @Override + public byte[] serializePublicKey(final PublicKey publicKey) { + return XECUtil.serializePublicKey(publicKey, getPublicKeyLength(), XECUtil.X25519_X509_PREFIX); + } + + @Override + public PublicKey deserializePublicKey(final byte[] publicKeyBytes) { + return XECUtil.deserializePublicKey(publicKeyBytes, getPublicKeyLength(), XECUtil.X25519_X509_PREFIX, keyFactory); + } + + @Override + public byte[] serializeSharedSecret(final SecretKey sharedSecret) { + // For DHKEM, the shared secret has a "raw" encoding + return sharedSecret.getEncoded(); + } + + @Override + public SecretKey deserializeSharedSecret(final byte[] sharedSecretBytes) { + return new SecretKeySpec(sharedSecretBytes, "Generic"); + } +} diff --git a/src/main/java/com/eatthepath/noise/component/NoiseKeyEncapsulationMechanism.java b/src/main/java/com/eatthepath/noise/component/NoiseKeyEncapsulationMechanism.java new file mode 100644 index 0000000..415f902 --- /dev/null +++ b/src/main/java/com/eatthepath/noise/component/NoiseKeyEncapsulationMechanism.java @@ -0,0 +1,45 @@ +package com.eatthepath.noise.component; + +import javax.crypto.KEM; +import javax.crypto.KEM; +import javax.crypto.SecretKey; +import java.security.KeyPair; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; + +public interface NoiseKeyEncapsulationMechanism { + + static NoiseKeyEncapsulationMechanism getInstance(final String name) throws NoSuchAlgorithmException { + return switch (name) { + case "DHKEM" -> new DhkemKeyEncapsulationMechanism(); + default -> throw new IllegalArgumentException("Unrecognized key encapsulation method name: " + name); + }; + } + + String getName(); + + KeyPair generateKeyPair(); + + /** + * + * @param publicKey the remote public key with which to encapsulate a shared secret + * + * @return an encapsulated shared secret key + */ + KEM.Encapsulated encapsulate(PublicKey publicKey); + + byte[] decapsulate(PrivateKey privateKey, byte[] encapsulation); + + int getPublicKeyLength(); + + int getEncapsulationLength(); + + byte[] serializePublicKey(PublicKey publicKey); + + PublicKey deserializePublicKey(byte[] publicKeyBytes); + + byte[] serializeSharedSecret(SecretKey sharedSecret); + + SecretKey deserializeSharedSecret(byte[] sharedSecretBytes); +} diff --git a/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java b/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java index 4690282..b725631 100644 --- a/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java +++ b/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java @@ -7,7 +7,6 @@ class X25519KeyAgreement extends AbstractXECKeyAgreement { private static final String ALGORITHM = "X25519"; - private static final byte[] X509_PREFIX = HexFormat.of().parseHex("302a300506032b656e032100"); public X25519KeyAgreement() throws NoSuchAlgorithmException { super(KeyAgreement.getInstance(ALGORITHM), KeyPairGenerator.getInstance(ALGORITHM), KeyFactory.getInstance(ALGORITHM)); @@ -25,6 +24,6 @@ public int getPublicKeyLength() { @Override protected byte[] getX509Prefix() { - return X509_PREFIX; + return XECUtil.X25519_X509_PREFIX; } } diff --git a/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java b/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java index e44d36e..3a281a9 100644 --- a/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java +++ b/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java @@ -7,7 +7,6 @@ class X448KeyAgreement extends AbstractXECKeyAgreement { private static final String ALGORITHM = "X448"; - private static final byte[] X509_PREFIX = HexFormat.of().parseHex("3042300506032b656f033900"); public X448KeyAgreement() throws NoSuchAlgorithmException { super(KeyAgreement.getInstance(ALGORITHM), KeyPairGenerator.getInstance(ALGORITHM), KeyFactory.getInstance(ALGORITHM)); @@ -25,6 +24,6 @@ public int getPublicKeyLength() { @Override protected byte[] getX509Prefix() { - return X509_PREFIX; + return XECUtil.X448_X509_PREFIX; } } diff --git a/src/main/java/com/eatthepath/noise/component/XECUtil.java b/src/main/java/com/eatthepath/noise/component/XECUtil.java new file mode 100644 index 0000000..ff9c00f --- /dev/null +++ b/src/main/java/com/eatthepath/noise/component/XECUtil.java @@ -0,0 +1,64 @@ +package com.eatthepath.noise.component; + +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.KeyFactory; +import java.security.PublicKey; +import java.security.interfaces.XECKey; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.NamedParameterSpec; +import java.security.spec.X509EncodedKeySpec; +import java.util.HexFormat; + +class XECUtil { + + static final byte[] X25519_X509_PREFIX = HexFormat.of().parseHex("302a300506032b656e032100"); + static final byte[] X448_X509_PREFIX = HexFormat.of().parseHex("3042300506032b656f033900"); + + private XECUtil() { + } + + static byte[] serializePublicKey(final PublicKey publicKey, final int publicKeyLength, final byte[] x509Prefix) { + // This is a little hacky, but the structure for an X.509 public key defines the order in which its elements appear. + // The first part of the key, which defines the algorithm and its parameters, is always the same for keys of the + // same type, and the last N bytes are the literal key material. + final byte[] serializedPublicKey = new byte[publicKeyLength]; + System.arraycopy(publicKey.getEncoded(), x509Prefix.length, serializedPublicKey, 0, publicKeyLength); + + return serializedPublicKey; + } + + static PublicKey deserializePublicKey(final byte[] publicKeyBytes, + final int publicKeyLength, + final byte[] x509Prefix, + final KeyFactory keyFactory) { + + if (publicKeyBytes.length != publicKeyLength) { + throw new IllegalArgumentException("Unexpected serialized public key length"); + } + + final byte[] x509Bytes = new byte[publicKeyLength + x509Prefix.length]; + System.arraycopy(x509Prefix, 0, x509Bytes, 0, x509Prefix.length); + System.arraycopy(publicKeyBytes, 0, x509Bytes, x509Prefix.length, publicKeyLength); + + try { + return keyFactory.generatePublic(new X509EncodedKeySpec(x509Bytes, keyFactory.getAlgorithm())); + } catch (final InvalidKeySpecException e) { + throw new IllegalArgumentException("Invalid key", e); + } + } + + static void checkKey(final Key key, final String algorithm) throws InvalidKeyException { + if (key instanceof XECKey xecKey) { + if (xecKey.getParams() instanceof NamedParameterSpec namedParameterSpec) { + if (!algorithm.equals(namedParameterSpec.getName())) { + throw new InvalidKeyException("Unexpected key algorithm: " + namedParameterSpec.getName()); + } + } else { + throw new InvalidKeyException("Unexpected key parameter type: " + xecKey.getParams().getClass()); + } + } else { + throw new InvalidKeyException("Unexpected key type: " + key.getClass()); + } + } +} diff --git a/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template b/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template index 8ad9782..3cd94c3 100644 --- a/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template +++ b/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template @@ -10,18 +10,20 @@ import java.util.List; import java.util.Objects; /** - * A Noise handshake builder constructs {@link NoiseHandshake} instances with known handshake patterns and roles. + *

A Noise handshake builder constructs {@link NoiseHandshake} instances with known handshake patterns and roles. * In contrast to {@link NamedProtocolHandshakeBuilder}, this builder provides compile-time checks that all required * keys are provided, but places the burden of selecting protocol components (key agreement algorithms, ciphers, and - * hash algorithms) on the caller. - *

- * Callers may specify the cryptographic components of a Noise protocol by providing a full Noise protocol name… - *

+ * hash algorithms) on the caller.

+ * + *

Callers may specify the cryptographic components of a Noise protocol by providing a full Noise protocol name…

+ * * {@snippet file="NoiseHandshakeBuilderExample.java" region="ik-handshake-protocol-name"} - *

- * …or by specifying the name of each component individually: - *

+ * + *

…or by specifying the name of each component individually:

+ * * {@snippet file="NoiseHandshakeBuilderExample.java" region="ik-handshake-component-names"} + * + * @see NamedProtocolHandshakeBuilder */ @SuppressWarnings("unused") public class NoiseHandshakeBuilder { @@ -38,6 +40,7 @@ public class NoiseHandshakeBuilder { @Nullable private NoiseCipher cipher; @Nullable private NoiseHash hash; @Nullable private NoiseKeyAgreement keyAgreement; + @Nullable private NoiseKeyEncapsulationMechanism keyEncapsulationMechanism; private NoiseHandshakeBuilder(final NoiseHandshake.Role role, final HandshakePattern handshakePattern, @@ -139,9 +142,19 @@ public class NoiseHandshakeBuilder { } /** - * Sets the key agreement algorithm to be used by this handshake. + *

Sets the key agreement algorithm to be used by this handshake. Note that key agreement names may be coupled with + * a key encapsulation mechanism name as described in Section 5 of "KEM-based Hybrid Forward Secrecy for Noise," which + * says, in part:

* - * @param keyAgreementName the name of the Noise key agreement to be used by this handshake + *
When the "hfs" modifier is used, the DH name section must contain a KEM algorithm name directly + * following the DH algorithm name, separated by a plus sign.
+ * + *

Calling this method with a KEM algorithm name in the {@code keyAgreementName} argument is the equivalent of + * calling this method with the bare key agreement algorithm name, then calling + * {@link #setKeyEncapsulationMechanism(String)} (String)} with the KEM algorithm name.

+ * + * @param keyAgreementName the name of the Noise key agreement (possibly including a KEM algorithm) to be used by + * this handshake * * @return a reference to this handshake builder * @@ -149,9 +162,44 @@ public class NoiseHandshakeBuilder { * @throws IllegalArgumentException if the given name is not recognized as a Noise key agreement algorithm name * * @see NoiseCipher#getInstance(String) + * @see KEM-based Hybrid Forward Secrecy for Noise, Section 5: The "hfs" modifier */ public NoiseHandshakeBuilder setKeyAgreement(final String keyAgreementName) throws NoSuchAlgorithmException { - this.keyAgreement = NoiseKeyAgreement.getInstance(Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null")); + final int separatorIndex = Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null") + .indexOf('+'); + + final String keyAgreementComponent; + + if (separatorIndex == -1) { + keyAgreementComponent = keyAgreementName; + } else { + keyAgreementComponent = keyAgreementName.substring(0, separatorIndex); + final String keyEncapsulationMechanismComponent = keyAgreementName.substring(separatorIndex); + + this.setKeyEncapsulationMechanism(keyEncapsulationMechanismComponent); + } + + this.keyAgreement = NoiseKeyAgreement.getInstance(keyAgreementComponent); + + return this; + } + + /** + * Sets the key encapsulation mechanism to be used by this handshake. + * + * @param keyEncapsulationMechanismName the name of the Noise key encapsulation mechanism to be used by this handshake + * + * @return a reference ot this handshake builder + * + * @throws NoSuchAlgorithmException if the named algorithm is not supported by the current JVM + * @throws IllegalArgumentException if the given name is not recognized as a Noise key encapsulation mechanism + * + * @see NoiseKeyEncapsulationMechanism#getInstance(String) + */ + public NoiseHandshakeBuilder setKeyEncapsulationMechanism(final String keyEncapsulationMechanismName) throws NoSuchAlgorithmException { + this.keyEncapsulationMechanism = NoiseKeyEncapsulationMechanism.getInstance( + Objects.requireNonNull(keyEncapsulationMechanismName, "Key encapsulation mechanism name must not be null")); + return this; } @@ -181,16 +229,25 @@ public class NoiseHandshakeBuilder { throw new IllegalArgumentException("Must set a key agreement algorithm before building a Noise handshake"); } + if (handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism == null) { + throw new IllegalArgumentException("Must set a key encapsulation mechanism before building a Noise handshake with an HFS pattern"); + } else if (!handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism != null) { + throw new IllegalArgumentException("Must not specify a key encapsulation mechanism for a non-HFS handshake pattern"); + } + return new NoiseHandshake(role, handshakePattern, keyAgreement, cipher, hash, + keyEncapsulationMechanism, prologue, localStaticKeyPair, null, + null, remoteStaticPublicKey, null, + null, preSharedKey != null ? List.of(preSharedKey) : null); } diff --git a/src/test/java/com/eatthepath/noise/HandshakePatternTest.java b/src/test/java/com/eatthepath/noise/HandshakePatternTest.java index 24f7284..d4ee2b0 100644 --- a/src/test/java/com/eatthepath/noise/HandshakePatternTest.java +++ b/src/test/java/com/eatthepath/noise/HandshakePatternTest.java @@ -169,6 +169,111 @@ void withPskModifier() throws NoSuchPatternException { } } + @ParameterizedTest + @MethodSource + void withHfsModifier(final HandshakePattern expectedHfsPattern) throws NoSuchPatternException { + final String fundamentalPatternName = HandshakePattern.getFundamentalPatternName(expectedHfsPattern.getName()); + + assertEquals(expectedHfsPattern, HandshakePattern.getInstance(fundamentalPatternName).withModifier("hfs")); + } + + private static List withHfsModifier() { + return List.of( + HandshakePattern.fromString(""" + NNhfs: + -> e, e1 + <- e, ee, ekem1 + """), + + HandshakePattern.fromString(""" + NKhfs: + <- s + ... + -> e, es, e1 + <- e, ee, ekem1 + """), + + HandshakePattern.fromString(""" + NXhfs: + -> e, e1 + <- e, ee, ekem1, s, es + """), + + HandshakePattern.fromString(""" + XNhfs: + -> e, e1 + <- e, ee, ekem1 + -> s, se + """), + + HandshakePattern.fromString(""" + XKhfs: + <- s + ... + -> e, es, e1 + <- e, ee, ekem1 + -> s, se + """), + + HandshakePattern.fromString(""" + XXhfs: + -> e, e1 + <- e, ee, ekem1, s, es + -> s, se + """), + + HandshakePattern.fromString(""" + KNhfs: + -> s + ... + -> e, e1 + <- e, ee, ekem1, se + """), + + // Note that this is different from what's listed at https://github.com/noiseprotocol/noise_hfs_spec/blob/025f0f60cb3b94ad75b68e3a4158b9aac234f8cb/noise_hfs.md?plain=1#L130-L135; + // the specification (at the time of writing) appears to have a typo. Please see + // https://github.com/noiseprotocol/noise_hfs_spec/pull/3. + HandshakePattern.fromString(""" + KKhfs: + -> s + <- s + ... + -> e, es, e1, ss + <- e, ee, ekem1, se + """), + + HandshakePattern.fromString(""" + KXhfs: + -> s + ... + -> e, e1 + <- e, ee, ekem1, se, s, es + """), + + // This also deviates from the latest version of the spec to fix a typo (the `ee` token is missing in the + // current draft of the spec). Please see https://github.com/noiseprotocol/noise_hfs_spec/pull/4. + HandshakePattern.fromString(""" + INhfs: + -> e, e1, s + <- e, ee, ekem1, se + """), + + HandshakePattern.fromString(""" + IKhfs: + <- s + ... + -> e, es, e1, s, ss + <- e, ee, ekem1, se + """), + + HandshakePattern.fromString(""" + IXhfs: + -> e, e1, s + <- e, ee, ekem1, se, s, es + """) + ); + } + @Test void withModifierUnrecognized() { assertThrows(IllegalArgumentException.class, () -> HandshakePattern.getInstance("XX").withModifier("fancy")); @@ -253,4 +358,29 @@ void requiresRemoteStaticPublicKey() throws NoSuchPatternException { assertTrue(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.RESPONDER)); assertFalse(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.INITIATOR)); } + + @Test + void messagePatternWithAddedToken() { + final MessagePattern originalPattern = new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.EE, Token.SE }); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E1, Token.E, Token.EE, Token.SE }), + originalPattern.withAddedToken(Token.E1, 0)); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.E1, Token.EE, Token.SE }), + originalPattern.withAddedToken(Token.E1, 1)); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.EE, Token.E1, Token.SE }), + originalPattern.withAddedToken(Token.E1, 2)); + + assertEquals(new HandshakePattern.MessagePattern(Role.INITIATOR, + new Token[] { Token.E, Token.EE, Token.SE, Token.E1 }), + originalPattern.withAddedToken(Token.E1, 3)); + + assertThrows(IllegalArgumentException.class, () -> originalPattern.withAddedToken(Token.E1, -1)); + assertThrows(IllegalArgumentException.class, () -> originalPattern.withAddedToken(Token.E1, 4)); + } } diff --git a/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java b/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java index 6ea8ed9..bba2a4f 100644 --- a/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java +++ b/src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java @@ -15,38 +15,40 @@ class NoiseHandshakeTest { @Test - void getOutboundMessageLength() throws NoSuchPatternException { - final HandshakePattern handshakePattern = HandshakePattern.getInstance("XX"); - - final int publicKeyLength = 56; + void getOutboundMessageLength() throws NoSuchAlgorithmException { + final NoiseHandshake noiseHandshake = + NoiseHandshakeBuilder.forXXInitiator(NoiseKeyAgreement.getInstance("448").generateKeyPair()) + .setComponentsFromProtocolName("Noise_XX_448_AESGCM_SHA256") + .build(); // Expected lengths via https://noiseprotocol.org/noise.html#message-format - assertEquals(56, NoiseHandshake.getOutboundMessageLength(handshakePattern, 0, publicKeyLength, 0)); - assertEquals(144, NoiseHandshake.getOutboundMessageLength(handshakePattern, 1, publicKeyLength, 0)); - assertEquals(88, NoiseHandshake.getOutboundMessageLength(handshakePattern, 2, publicKeyLength, 0)); + assertEquals(56, noiseHandshake.getOutboundMessageLength(0, 0)); + assertEquals(144, noiseHandshake.getOutboundMessageLength(1, 0)); + assertEquals(88, noiseHandshake.getOutboundMessageLength(2, 0)); - assertEquals(59, NoiseHandshake.getOutboundMessageLength(handshakePattern, 0, publicKeyLength, 3)); - assertEquals(149, NoiseHandshake.getOutboundMessageLength(handshakePattern, 1, publicKeyLength, 5)); - assertEquals(95, NoiseHandshake.getOutboundMessageLength(handshakePattern, 2, publicKeyLength, 7)); + assertEquals(59, noiseHandshake.getOutboundMessageLength(0, 3)); + assertEquals(149, noiseHandshake.getOutboundMessageLength(1, 5)); + assertEquals(95, noiseHandshake.getOutboundMessageLength(2, 7)); } @Test - void getPayloadLength() throws NoSuchPatternException { - final HandshakePattern handshakePattern = HandshakePattern.getInstance("XX"); - - final int publicKeyLength = 56; + void getPayloadLength() throws NoSuchAlgorithmException { + final NoiseHandshake noiseHandshake = + NoiseHandshakeBuilder.forXXInitiator(NoiseKeyAgreement.getInstance("448").generateKeyPair()) + .setComponentsFromProtocolName("Noise_XX_448_AESGCM_SHA256") + .build(); // Expected lengths via https://noiseprotocol.org/noise.html#message-format - assertEquals(0, NoiseHandshake.getPayloadLength(handshakePattern, 0, publicKeyLength, 56)); - assertEquals(0, NoiseHandshake.getPayloadLength(handshakePattern, 1, publicKeyLength, 144)); - assertEquals(0, NoiseHandshake.getPayloadLength(handshakePattern, 2, publicKeyLength, 88)); + assertEquals(0, noiseHandshake.getPayloadLength(0, 56)); + assertEquals(0, noiseHandshake.getPayloadLength(1, 144)); + assertEquals(0, noiseHandshake.getPayloadLength(2, 88)); - assertEquals(3, NoiseHandshake.getPayloadLength(handshakePattern, 0, publicKeyLength, 59)); - assertEquals(5, NoiseHandshake.getPayloadLength(handshakePattern, 1, publicKeyLength, 149)); - assertEquals(7, NoiseHandshake.getPayloadLength(handshakePattern, 2, publicKeyLength, 95)); + assertEquals(3, noiseHandshake.getPayloadLength(0, 59)); + assertEquals(5, noiseHandshake.getPayloadLength(1, 149)); + assertEquals(7, noiseHandshake.getPayloadLength(2, 95)); assertThrows(IllegalArgumentException.class, - () -> NoiseHandshake.getPayloadLength(handshakePattern, 0, publicKeyLength, 55)); + () -> noiseHandshake.getPayloadLength(0, 55)); } @Test diff --git a/src/test/java/com/eatthepath/noise/NoiseProtocolIntegrationTest.java b/src/test/java/com/eatthepath/noise/NoiseProtocolIntegrationTest.java index 1b09186..bec1935 100644 --- a/src/test/java/com/eatthepath/noise/NoiseProtocolIntegrationTest.java +++ b/src/test/java/com/eatthepath/noise/NoiseProtocolIntegrationTest.java @@ -8,6 +8,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Named; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -18,6 +19,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.security.*; import java.security.spec.NamedParameterSpec; import java.util.List; @@ -665,4 +667,78 @@ public void nextBytes(final byte[] bytes) { throw new RuntimeException(e); } } + + @Test + void dhkemHfs() throws NoSuchAlgorithmException, AEADBadTagException { + final NoiseHandshake initiatorHandshake = NoiseHandshakeBuilder.forNNHfsInitiator() + .setKeyAgreement("25519") + .setKeyEncapsulationMechanism("DHKEM") + .setCipher("AESGCM") + .setHash("SHA256") + .build(); + + final NoiseHandshake responderHandshake = NoiseHandshakeBuilder.forNNHfsResponder() + .setKeyAgreement("25519") + .setKeyEncapsulationMechanism("DHKEM") + .setCipher("AESGCM") + .setHash("SHA256") + .build(); + + // -> e (with an empty payload) + final byte[] initiatorEMessage = initiatorHandshake.writeMessage((byte[]) null); + responderHandshake.readMessage(initiatorEMessage); + + // <- e, ee (with an empty payload) + final byte[] responderEEeMessage = responderHandshake.writeMessage((byte[]) null); + initiatorHandshake.readMessage(responderEEeMessage); + + assertTrue(initiatorHandshake.isDone()); + assertTrue(responderHandshake.isDone()); + + final NoiseTransport initiatorTransport = initiatorHandshake.toTransport(); + final NoiseTransport responderTransport = responderHandshake.toTransport(); + + final byte[] originalPlaintext = "Original payload!".getBytes(StandardCharsets.UTF_8); + final byte[] originalCiphertext = initiatorTransport.writeMessage(originalPlaintext); + final byte[] decryptedPlaintext = responderTransport.readMessage(originalCiphertext); + + assertArrayEquals(originalPlaintext, decryptedPlaintext); + } + + @Test + void dhkemHfsByteBuffer() throws NoSuchAlgorithmException, AEADBadTagException { + final NoiseHandshake initiatorHandshake = NoiseHandshakeBuilder.forNNHfsInitiator() + .setKeyAgreement("25519") + .setKeyEncapsulationMechanism("DHKEM") + .setCipher("AESGCM") + .setHash("SHA256") + .build(); + + final NoiseHandshake responderHandshake = NoiseHandshakeBuilder.forNNHfsResponder() + .setKeyAgreement("25519") + .setKeyEncapsulationMechanism("DHKEM") + .setCipher("AESGCM") + .setHash("SHA256") + .build(); + + // -> e (with an empty payload) + final ByteBuffer initiatorEMessage = initiatorHandshake.writeMessage((ByteBuffer) null); + responderHandshake.readMessage(initiatorEMessage); + + // <- e, ee (with an empty payload) + final ByteBuffer responderEEeMessage = responderHandshake.writeMessage((ByteBuffer) null); + initiatorHandshake.readMessage(responderEEeMessage); + + assertTrue(initiatorHandshake.isDone()); + assertTrue(responderHandshake.isDone()); + + final NoiseTransport initiatorTransport = initiatorHandshake.toTransport(); + final NoiseTransport responderTransport = responderHandshake.toTransport(); + + final ByteBuffer originalPlaintext = ByteBuffer.wrap("Original payload!".getBytes(StandardCharsets.UTF_8)); + final ByteBuffer originalCiphertext = initiatorTransport.writeMessage(originalPlaintext); + final ByteBuffer decryptedPlaintext = responderTransport.readMessage(originalCiphertext); + + assertEquals(originalPlaintext.rewind(), decryptedPlaintext); + } }