Sets the key agreement algorithm to be used by this handshake. Note that key agreement names may be coupled with + * a key encapsulation mechanism name as described in Section 5 of "KEM-based Hybrid Forward Secrecy for Noise," which + * says, in part:
* - * @param keyAgreementName the name of the Noise key agreement to be used by this handshake + *When the "hfs" modifier is used, the DH name section must contain a KEM algorithm name directly + * following the DH algorithm name, separated by a plus sign.+ * + *
Calling this method with a KEM algorithm name in the {@code keyAgreementName} argument is the equivalent of + * calling this method with the bare key agreement algorithm name, then calling + * {@link #setKeyEncapsulationMechanism(String)} with the KEM algorithm name.
+ * + * @param keyAgreementName the name of the Noise key agreement (possibly including a KEM algorithm) to be used by + * this handshake * * @return a reference to this handshake builder * @@ -151,9 +162,44 @@ public NoiseHandshakeBuilder setHash(final String hashName) throws NoSuchAlgorit * @throws IllegalArgumentException if the given name is not recognized as a Noise key agreement algorithm name * * @see NoiseCipher#getInstance(String) + * @see KEM-based Hybrid Forward Secrecy for Noise, Section 5: The "hfs" modifier */ public NoiseHandshakeBuilder setKeyAgreement(final String keyAgreementName) throws NoSuchAlgorithmException { - this.keyAgreement = NoiseKeyAgreement.getInstance(Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null")); + final int separatorIndex = Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null") + .indexOf('+'); + + final String keyAgreementComponent; + + if (separatorIndex == -1) { + keyAgreementComponent = keyAgreementName; + } else { + keyAgreementComponent = keyAgreementName.substring(0, separatorIndex); + final String keyEncapsulationMechanismComponent = keyAgreementName.substring(separatorIndex); + + this.setKeyEncapsulationMechanism(keyEncapsulationMechanismComponent); + } + + this.keyAgreement = NoiseKeyAgreement.getInstance(keyAgreementComponent); + + return this; + } + + /** + * Sets the key encapsulation mechanism to be used by this handshake. + * + * @param keyEncapsulationMechanismName the name of the Noise key encapsulation mechanism to be used by this handshake + * + * @return a reference ot this handshake builder + * + * @throws NoSuchAlgorithmException if the named algorithm is not supported by the current JVM + * @throws IllegalArgumentException if the given name is not recognized as a Noise key encapsulation mechanism + * + * @see NoiseKeyEncapsulationMechanism#getInstance(String) + */ + public NoiseHandshakeBuilder setKeyEncapsulationMechanism(final String keyEncapsulationMechanismName) throws NoSuchAlgorithmException { + this.keyEncapsulationMechanism = NoiseKeyEncapsulationMechanism.getInstance( + Objects.requireNonNull(keyEncapsulationMechanismName, "Key encapsulation mechanism name must not be null")); + return this; } @@ -183,16 +229,25 @@ public NoiseHandshake build() { throw new IllegalArgumentException("Must set a key agreement algorithm before building a Noise handshake"); } + if (handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism == null) { + throw new IllegalArgumentException("Must set a key encapsulation mechanism before building a Noise handshake with an HFS pattern"); + } else if (!handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism != null) { + throw new IllegalArgumentException("Must not specify a key encapsulation mechanism for a non-HFS handshake pattern"); + } + return new NoiseHandshake(role, handshakePattern, keyAgreement, cipher, hash, + keyEncapsulationMechanism, prologue, localStaticKeyPair, null, + null, remoteStaticPublicKey, null, + null, preSharedKey != null ? List.of(preSharedKey) : null); } @@ -3071,4 +3126,580 @@ public static NoiseHandshakeBuilder forIXPsk2Responder(final KeyPair localStatic throw new AssertionError("Statically-generated handshake pattern not found", e); } } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * NNhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forNNHfsInitiator() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("NNhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * NNhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forNNHfsResponder() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("NNhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in a + * KNhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKNHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("KNhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in a + * KNhfs handshake. + * + * + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKNHfsResponder(final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("KNhfs"), + null, + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * NKhfs handshake. + * + * + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forNKHfsInitiator(final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("NKhfs"), + null, + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * NKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forNKHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("NKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in a + * KKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKKHfsInitiator(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("KKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in a + * KKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKKHfsResponder(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("KKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * NXhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forNXHfsInitiator() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("NXhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * NXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forNXHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("NXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in a + * KXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKXHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("KXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in a + * KXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forKXHfsResponder(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("KXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * XNhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXNHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("XNhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * XNhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forXNHfsResponder() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("XNhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * INhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forINHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("INhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * INhfs handshake. + * + * + * + * + * + * @return a new Noise handshake builder + * + * + */ + public static NoiseHandshakeBuilder forINHfsResponder() { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("INhfs"), + null, + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * XKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXKHfsInitiator(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("XKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * XKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXKHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("XKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * IKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * @param remoteStaticPublicKey the remote static public key for this handshake; must not be {@code null} + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIKHfsInitiator(final KeyPair localStaticKeyPair, final PublicKey remoteStaticPublicKey) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("IKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + Objects.requireNonNull(remoteStaticPublicKey, "Remote static public key must not be null"), + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * IKhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIKHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("IKhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * XXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXXHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("XXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * XXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forXXHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("XXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the initiator in an + * IXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIXHfsInitiator(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.INITIATOR, + HandshakePattern.getInstance("IXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } + + /** + * Constructs a new Noise handshake builder for the responder in an + * IXhfs handshake. + * + * @param localStaticKeyPair the local static key pair for this handshake; must not be {@code null} + * + * + * + * @return a new Noise handshake builder + * + * @throws NullPointerException if any required key {@code null} + */ + public static NoiseHandshakeBuilder forIXHfsResponder(final KeyPair localStaticKeyPair) { + try { + return new NoiseHandshakeBuilder(NoiseHandshake.Role.RESPONDER, + HandshakePattern.getInstance("IXhfs"), + Objects.requireNonNull(localStaticKeyPair, "Local static key pair must not be null"), + null, + null); + } catch (final NoSuchPatternException e) { + throw new AssertionError("Statically-generated handshake pattern not found", e); + } + } } diff --git a/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java b/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java index 254352c..169b64e 100644 --- a/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java +++ b/src/main/java/com/eatthepath/noise/component/AbstractXECKeyAgreement.java @@ -42,57 +42,22 @@ public byte[] generateSecret(final PrivateKey privateKey, final PublicKey public @Override public byte[] serializePublicKey(final PublicKey publicKey) { - // This is a little hacky, but the structure for an X.509 public key defines the order in which its elements appear. - // The first part of the key, which defines the algorithm and its parameters, is always the same for keys of the - // same type, and the last N bytes are the literal key material. - final byte[] serializedPublicKey = new byte[getPublicKeyLength()]; - System.arraycopy(publicKey.getEncoded(), getX509Prefix().length, serializedPublicKey, 0, getPublicKeyLength()); - - return serializedPublicKey; + return XECUtil.serializePublicKey(publicKey, getPublicKeyLength(), getX509Prefix()); } @Override public PublicKey deserializePublicKey(final byte[] publicKeyBytes) { - final int publicKeyLength = getPublicKeyLength(); - - if (publicKeyBytes.length != publicKeyLength) { - throw new IllegalArgumentException("Unexpected serialized public key length"); - } - - final byte[] x509Prefix = getX509Prefix(); - final byte[] x509Bytes = new byte[publicKeyLength + x509Prefix.length]; - System.arraycopy(x509Prefix, 0, x509Bytes, 0, x509Prefix.length); - System.arraycopy(publicKeyBytes, 0, x509Bytes, x509Prefix.length, publicKeyLength); - - try { - return keyFactory.generatePublic(new X509EncodedKeySpec(x509Bytes, keyFactory.getAlgorithm())); - } catch (final InvalidKeySpecException e) { - throw new IllegalArgumentException("Invalid key", e); - } + return XECUtil.deserializePublicKey(publicKeyBytes, getPublicKeyLength(), getX509Prefix(), keyFactory); } @Override public void checkPublicKey(final PublicKey publicKey) throws InvalidKeyException { - checkKey(publicKey); + XECUtil.checkKey(publicKey, keyAgreement.getAlgorithm()); } @Override public void checkKeyPair(final KeyPair keyPair) throws InvalidKeyException { - checkKey(keyPair.getPublic()); - checkKey(keyPair.getPrivate()); - } - - private void checkKey(final Key key) throws InvalidKeyException { - if (key instanceof XECKey xecKey) { - if (xecKey.getParams() instanceof NamedParameterSpec namedParameterSpec) { - if (!keyAgreement.getAlgorithm().equals(namedParameterSpec.getName())) { - throw new InvalidKeyException("Unexpected key algorithm: " + namedParameterSpec.getName()); - } - } else { - throw new InvalidKeyException("Unexpected key parameter type: " + xecKey.getParams().getClass()); - } - } else { - throw new InvalidKeyException("Unexpected key type: " + key.getClass()); - } + XECUtil.checkKey(keyPair.getPublic(), keyAgreement.getAlgorithm()); + XECUtil.checkKey(keyPair.getPrivate(), keyAgreement.getAlgorithm()); } } diff --git a/src/main/java/com/eatthepath/noise/component/DhkemKeyEncapsulationMechanism.java b/src/main/java/com/eatthepath/noise/component/DhkemKeyEncapsulationMechanism.java new file mode 100644 index 0000000..57ff3ec --- /dev/null +++ b/src/main/java/com/eatthepath/noise/component/DhkemKeyEncapsulationMechanism.java @@ -0,0 +1,81 @@ +package com.eatthepath.noise.component; + +import javax.crypto.DecapsulateException; +import javax.crypto.KEM; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import java.security.*; + +public class DhkemKeyEncapsulationMechanism implements NoiseKeyEncapsulationMechanism { + + private final KEM kem; + private final KeyPairGenerator keyPairGenerator; + private final KeyFactory keyFactory; + + public DhkemKeyEncapsulationMechanism() throws NoSuchAlgorithmException { + this.keyPairGenerator = KeyPairGenerator.getInstance("X25519"); + this.keyFactory = KeyFactory.getInstance("X25519"); + this.kem = KEM.getInstance("DHKEM"); + } + + @Override + public String getName() { + return "DHKEM"; + } + + @Override + public KeyPair generateKeyPair() { + return keyPairGenerator.generateKeyPair(); + } + + @Override + public KEM.Encapsulated encapsulate(final PublicKey publicKey) { + try { + return kem.newEncapsulator(publicKey).encapsulate(); + } catch (final InvalidKeyException e) { + throw new IllegalArgumentException("Invalid public key for encapsulation", e); + } + } + + @Override + public byte[] decapsulate(final PrivateKey privateKey, final byte[] encapsulation) { + try { + return serializeSharedSecret(kem.newDecapsulator(privateKey).decapsulate(encapsulation)); + } catch (final DecapsulateException e) { + throw new IllegalArgumentException("Invalid encapsulation", e); + } catch (final InvalidKeyException e) { + throw new IllegalArgumentException("Invalid private key for decapsulation", e); + } + } + + @Override + public int getPublicKeyLength() { + return 32; + } + + @Override + public int getEncapsulationLength() { + return 32; + } + + @Override + public byte[] serializePublicKey(final PublicKey publicKey) { + return XECUtil.serializePublicKey(publicKey, getPublicKeyLength(), XECUtil.X25519_X509_PREFIX); + } + + @Override + public PublicKey deserializePublicKey(final byte[] publicKeyBytes) { + return XECUtil.deserializePublicKey(publicKeyBytes, getPublicKeyLength(), XECUtil.X25519_X509_PREFIX, keyFactory); + } + + @Override + public byte[] serializeSharedSecret(final SecretKey sharedSecret) { + // For DHKEM, the shared secret has a "raw" encoding + return sharedSecret.getEncoded(); + } + + @Override + public SecretKey deserializeSharedSecret(final byte[] sharedSecretBytes) { + return new SecretKeySpec(sharedSecretBytes, "Generic"); + } +} diff --git a/src/main/java/com/eatthepath/noise/component/NoiseKeyEncapsulationMechanism.java b/src/main/java/com/eatthepath/noise/component/NoiseKeyEncapsulationMechanism.java new file mode 100644 index 0000000..415f902 --- /dev/null +++ b/src/main/java/com/eatthepath/noise/component/NoiseKeyEncapsulationMechanism.java @@ -0,0 +1,45 @@ +package com.eatthepath.noise.component; + +import javax.crypto.KEM; +import javax.crypto.KEM; +import javax.crypto.SecretKey; +import java.security.KeyPair; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; + +public interface NoiseKeyEncapsulationMechanism { + + static NoiseKeyEncapsulationMechanism getInstance(final String name) throws NoSuchAlgorithmException { + return switch (name) { + case "DHKEM" -> new DhkemKeyEncapsulationMechanism(); + default -> throw new IllegalArgumentException("Unrecognized key encapsulation method name: " + name); + }; + } + + String getName(); + + KeyPair generateKeyPair(); + + /** + * + * @param publicKey the remote public key with which to encapsulate a shared secret + * + * @return an encapsulated shared secret key + */ + KEM.Encapsulated encapsulate(PublicKey publicKey); + + byte[] decapsulate(PrivateKey privateKey, byte[] encapsulation); + + int getPublicKeyLength(); + + int getEncapsulationLength(); + + byte[] serializePublicKey(PublicKey publicKey); + + PublicKey deserializePublicKey(byte[] publicKeyBytes); + + byte[] serializeSharedSecret(SecretKey sharedSecret); + + SecretKey deserializeSharedSecret(byte[] sharedSecretBytes); +} diff --git a/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java b/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java index 4690282..b725631 100644 --- a/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java +++ b/src/main/java/com/eatthepath/noise/component/X25519KeyAgreement.java @@ -7,7 +7,6 @@ class X25519KeyAgreement extends AbstractXECKeyAgreement { private static final String ALGORITHM = "X25519"; - private static final byte[] X509_PREFIX = HexFormat.of().parseHex("302a300506032b656e032100"); public X25519KeyAgreement() throws NoSuchAlgorithmException { super(KeyAgreement.getInstance(ALGORITHM), KeyPairGenerator.getInstance(ALGORITHM), KeyFactory.getInstance(ALGORITHM)); @@ -25,6 +24,6 @@ public int getPublicKeyLength() { @Override protected byte[] getX509Prefix() { - return X509_PREFIX; + return XECUtil.X25519_X509_PREFIX; } } diff --git a/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java b/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java index e44d36e..3a281a9 100644 --- a/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java +++ b/src/main/java/com/eatthepath/noise/component/X448KeyAgreement.java @@ -7,7 +7,6 @@ class X448KeyAgreement extends AbstractXECKeyAgreement { private static final String ALGORITHM = "X448"; - private static final byte[] X509_PREFIX = HexFormat.of().parseHex("3042300506032b656f033900"); public X448KeyAgreement() throws NoSuchAlgorithmException { super(KeyAgreement.getInstance(ALGORITHM), KeyPairGenerator.getInstance(ALGORITHM), KeyFactory.getInstance(ALGORITHM)); @@ -25,6 +24,6 @@ public int getPublicKeyLength() { @Override protected byte[] getX509Prefix() { - return X509_PREFIX; + return XECUtil.X448_X509_PREFIX; } } diff --git a/src/main/java/com/eatthepath/noise/component/XECUtil.java b/src/main/java/com/eatthepath/noise/component/XECUtil.java new file mode 100644 index 0000000..ff9c00f --- /dev/null +++ b/src/main/java/com/eatthepath/noise/component/XECUtil.java @@ -0,0 +1,64 @@ +package com.eatthepath.noise.component; + +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.KeyFactory; +import java.security.PublicKey; +import java.security.interfaces.XECKey; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.NamedParameterSpec; +import java.security.spec.X509EncodedKeySpec; +import java.util.HexFormat; + +class XECUtil { + + static final byte[] X25519_X509_PREFIX = HexFormat.of().parseHex("302a300506032b656e032100"); + static final byte[] X448_X509_PREFIX = HexFormat.of().parseHex("3042300506032b656f033900"); + + private XECUtil() { + } + + static byte[] serializePublicKey(final PublicKey publicKey, final int publicKeyLength, final byte[] x509Prefix) { + // This is a little hacky, but the structure for an X.509 public key defines the order in which its elements appear. + // The first part of the key, which defines the algorithm and its parameters, is always the same for keys of the + // same type, and the last N bytes are the literal key material. + final byte[] serializedPublicKey = new byte[publicKeyLength]; + System.arraycopy(publicKey.getEncoded(), x509Prefix.length, serializedPublicKey, 0, publicKeyLength); + + return serializedPublicKey; + } + + static PublicKey deserializePublicKey(final byte[] publicKeyBytes, + final int publicKeyLength, + final byte[] x509Prefix, + final KeyFactory keyFactory) { + + if (publicKeyBytes.length != publicKeyLength) { + throw new IllegalArgumentException("Unexpected serialized public key length"); + } + + final byte[] x509Bytes = new byte[publicKeyLength + x509Prefix.length]; + System.arraycopy(x509Prefix, 0, x509Bytes, 0, x509Prefix.length); + System.arraycopy(publicKeyBytes, 0, x509Bytes, x509Prefix.length, publicKeyLength); + + try { + return keyFactory.generatePublic(new X509EncodedKeySpec(x509Bytes, keyFactory.getAlgorithm())); + } catch (final InvalidKeySpecException e) { + throw new IllegalArgumentException("Invalid key", e); + } + } + + static void checkKey(final Key key, final String algorithm) throws InvalidKeyException { + if (key instanceof XECKey xecKey) { + if (xecKey.getParams() instanceof NamedParameterSpec namedParameterSpec) { + if (!algorithm.equals(namedParameterSpec.getName())) { + throw new InvalidKeyException("Unexpected key algorithm: " + namedParameterSpec.getName()); + } + } else { + throw new InvalidKeyException("Unexpected key parameter type: " + xecKey.getParams().getClass()); + } + } else { + throw new InvalidKeyException("Unexpected key type: " + key.getClass()); + } + } +} diff --git a/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template b/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template index 8ad9782..3cd94c3 100644 --- a/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template +++ b/src/main/resources/com/eatthepath/noise/NoiseHandshakeBuilder.java.template @@ -10,18 +10,20 @@ import java.util.List; import java.util.Objects; /** - * A Noise handshake builder constructs {@link NoiseHandshake} instances with known handshake patterns and roles. + *A Noise handshake builder constructs {@link NoiseHandshake} instances with known handshake patterns and roles. * In contrast to {@link NamedProtocolHandshakeBuilder}, this builder provides compile-time checks that all required * keys are provided, but places the burden of selecting protocol components (key agreement algorithms, ciphers, and - * hash algorithms) on the caller. - *
- * Callers may specify the cryptographic components of a Noise protocol by providing a full Noise protocol name… - *
+ * hash algorithms) on the caller.
+ * + *Callers may specify the cryptographic components of a Noise protocol by providing a full Noise protocol name…
+ * * {@snippet file="NoiseHandshakeBuilderExample.java" region="ik-handshake-protocol-name"} - *- * …or by specifying the name of each component individually: - *
+ * + *
…or by specifying the name of each component individually:
+ * * {@snippet file="NoiseHandshakeBuilderExample.java" region="ik-handshake-component-names"} + * + * @see NamedProtocolHandshakeBuilder */ @SuppressWarnings("unused") public class NoiseHandshakeBuilder { @@ -38,6 +40,7 @@ public class NoiseHandshakeBuilder { @Nullable private NoiseCipher cipher; @Nullable private NoiseHash hash; @Nullable private NoiseKeyAgreement keyAgreement; + @Nullable private NoiseKeyEncapsulationMechanism keyEncapsulationMechanism; private NoiseHandshakeBuilder(final NoiseHandshake.Role role, final HandshakePattern handshakePattern, @@ -139,9 +142,19 @@ public class NoiseHandshakeBuilder { } /** - * Sets the key agreement algorithm to be used by this handshake. + *Sets the key agreement algorithm to be used by this handshake. Note that key agreement names may be coupled with + * a key encapsulation mechanism name as described in Section 5 of "KEM-based Hybrid Forward Secrecy for Noise," which + * says, in part:
* - * @param keyAgreementName the name of the Noise key agreement to be used by this handshake + *When the "hfs" modifier is used, the DH name section must contain a KEM algorithm name directly + * following the DH algorithm name, separated by a plus sign.+ * + *
Calling this method with a KEM algorithm name in the {@code keyAgreementName} argument is the equivalent of + * calling this method with the bare key agreement algorithm name, then calling + * {@link #setKeyEncapsulationMechanism(String)} (String)} with the KEM algorithm name.
+ * + * @param keyAgreementName the name of the Noise key agreement (possibly including a KEM algorithm) to be used by + * this handshake * * @return a reference to this handshake builder * @@ -149,9 +162,44 @@ public class NoiseHandshakeBuilder { * @throws IllegalArgumentException if the given name is not recognized as a Noise key agreement algorithm name * * @see NoiseCipher#getInstance(String) + * @see KEM-based Hybrid Forward Secrecy for Noise, Section 5: The "hfs" modifier */ public NoiseHandshakeBuilder setKeyAgreement(final String keyAgreementName) throws NoSuchAlgorithmException { - this.keyAgreement = NoiseKeyAgreement.getInstance(Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null")); + final int separatorIndex = Objects.requireNonNull(keyAgreementName, "Key agreement algorithm must not be null") + .indexOf('+'); + + final String keyAgreementComponent; + + if (separatorIndex == -1) { + keyAgreementComponent = keyAgreementName; + } else { + keyAgreementComponent = keyAgreementName.substring(0, separatorIndex); + final String keyEncapsulationMechanismComponent = keyAgreementName.substring(separatorIndex); + + this.setKeyEncapsulationMechanism(keyEncapsulationMechanismComponent); + } + + this.keyAgreement = NoiseKeyAgreement.getInstance(keyAgreementComponent); + + return this; + } + + /** + * Sets the key encapsulation mechanism to be used by this handshake. + * + * @param keyEncapsulationMechanismName the name of the Noise key encapsulation mechanism to be used by this handshake + * + * @return a reference ot this handshake builder + * + * @throws NoSuchAlgorithmException if the named algorithm is not supported by the current JVM + * @throws IllegalArgumentException if the given name is not recognized as a Noise key encapsulation mechanism + * + * @see NoiseKeyEncapsulationMechanism#getInstance(String) + */ + public NoiseHandshakeBuilder setKeyEncapsulationMechanism(final String keyEncapsulationMechanismName) throws NoSuchAlgorithmException { + this.keyEncapsulationMechanism = NoiseKeyEncapsulationMechanism.getInstance( + Objects.requireNonNull(keyEncapsulationMechanismName, "Key encapsulation mechanism name must not be null")); + return this; } @@ -181,16 +229,25 @@ public class NoiseHandshakeBuilder { throw new IllegalArgumentException("Must set a key agreement algorithm before building a Noise handshake"); } + if (handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism == null) { + throw new IllegalArgumentException("Must set a key encapsulation mechanism before building a Noise handshake with an HFS pattern"); + } else if (!handshakePattern.requiresKeyEncapsulationMechanism() && keyEncapsulationMechanism != null) { + throw new IllegalArgumentException("Must not specify a key encapsulation mechanism for a non-HFS handshake pattern"); + } + return new NoiseHandshake(role, handshakePattern, keyAgreement, cipher, hash, + keyEncapsulationMechanism, prologue, localStaticKeyPair, null, + null, remoteStaticPublicKey, null, + null, preSharedKey != null ? List.of(preSharedKey) : null); } diff --git a/src/test/java/com/eatthepath/noise/HandshakePatternTest.java b/src/test/java/com/eatthepath/noise/HandshakePatternTest.java index 24f7284..d4ee2b0 100644 --- a/src/test/java/com/eatthepath/noise/HandshakePatternTest.java +++ b/src/test/java/com/eatthepath/noise/HandshakePatternTest.java @@ -169,6 +169,111 @@ void withPskModifier() throws NoSuchPatternException { } } + @ParameterizedTest + @MethodSource + void withHfsModifier(final HandshakePattern expectedHfsPattern) throws NoSuchPatternException { + final String fundamentalPatternName = HandshakePattern.getFundamentalPatternName(expectedHfsPattern.getName()); + + assertEquals(expectedHfsPattern, HandshakePattern.getInstance(fundamentalPatternName).withModifier("hfs")); + } + + private static List