diff --git a/README.rst b/README.rst
index c5d2a83..5165e63 100644
--- a/README.rst
+++ b/README.rst
@@ -46,6 +46,9 @@ This library provides cryptographic operations that are compatible with nilDB no
 |             +------------+------------------------------------+------------------------------------+
 |             | sum        | | additive secret sharing          | 32-bit signed integer              |
 |             |            | | (prime modulus 2^32 + 15)        |                                    |
+|             +------------+------------------------------------+------------------------------------+
+|             | redundancy | | Shamir secret sharing            | 32-bit signed integer              |
+|             |            | | (prime modulus 2^32 + 15)        |                                    |
 +-------------+------------+------------------------------------+------------------------------------+
 
 Installation and Usage
diff --git a/pyproject.toml b/pyproject.toml
index 9abef4a..edfb462 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,6 +6,7 @@ license = {text = "MIT"}
 readme = "README.rst"
 requires-python = ">=3.9"
 dependencies = [
+    "lagrange~=3.0",
     "bcl~=2.3",
     "pailliers~=0.1"
 ]
diff --git a/src/nilql/nilql.py b/src/nilql/nilql.py
index fa9f8a9..f810b2d 100644
--- a/src/nilql/nilql.py
+++ b/src/nilql/nilql.py
@@ -9,6 +9,7 @@
 import secrets
 import hashlib
 import hmac
+from lagrange import lagrange
 import bcl
 import pailliers
 
@@ -27,6 +28,9 @@
 _HASH = hashlib.sha512
 """Hash function used for HKDF and matching."""
 
+_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2
+"""Minimum number of shares required to reconstruct a Shamir secret."""
+
 def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes:
     """
     Extracts a pseudorandom key (PRK) using HMAC with the given salt and input key material.
@@ -108,6 +112,57 @@ def _random_int(
 
     return minimum + secrets.randbelow(maximum + 1 - minimum)
 
+def _shamirs_eval(poly, x, prime):
+    """
+    Evaluates a polynomial (coefficients in ascending order of degree) at x modulo prime.
+    """
+    accum = 0
+    for coeff in reversed(poly):
+        accum *= x
+        accum += coeff
+        accum %= prime
+    return accum
+
+def _shamirs_shares(
+        secret,
+        total_shares,
+        minimum_shares=_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION,
+        prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS
+    ):
+    """
+    Returns share points [x, f(x)] for x = 1..total_shares of a random f with f(0) = secret.
+    """
+    if minimum_shares > total_shares:
+        raise ValueError("Pool secret would be irrecoverable.")
+
+    poly = [secret] + [secrets.randbelow(prime) for _ in range(minimum_shares - 1)]
+    points = [[i, _shamirs_eval(poly, i, prime)] for i in range(1, total_shares + 1)]
+    return points
+
+def _shamirs_recover(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
+    """
+    Recovers the secret from share points via Lagrange interpolation modulo prime.
+    """
+    if len(shares) < _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION:
+        raise ValueError(
+            f'need at least {_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares'
+        )
+
+    return lagrange(shares, prime)
+
+def _shamirs_add(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
+    """
+    Adds two sets of shares pointwise; both must list the same x-values in the same order.
+    """
+    if len(shares1) != len(shares2):
+        raise ValueError('shares sets must have the same length')
+    if any(s1[0] != s2[0] for (s1, s2) in zip(shares1, shares2)):
+        raise ValueError('shares sets must use the same x-values')
+
+    return [
+        [x, (y1 + y2) % prime] for ((x, y1), (_, y2)) in zip(shares1, shares2)
+    ]
+
 def _pack(b: bytes) -> str:
     """
     Encode a bytes-like object as a Base64 string (for compatibility with JSON).
@@ -234,7 +289,7 @@ def generate(
 
         if (
             (not isinstance(operations, dict)) or
-            (not set(operations.keys()).issubset({'store', 'match', 'sum'}))
+            (not set(operations.keys()).issubset({'store', 'match', 'sum', 'redundancy'}))
         ):
             raise ValueError('valid operations specification is required')
 
@@ -277,6 +332,25 @@ def generate(
                     for i in range(len(secret_key['cluster']['nodes']))
                 ]
 
+        if secret_key['operations'].get('redundancy'):
+            if len(secret_key['cluster']['nodes']) == 1:
+                raise RuntimeError(
+                    'Redundancy is not supported for single-node clusters'
+                )
+            # Distinct multiplicative mask for each Shamir share.
+            secret_key['material'] = [
+                _random_int(
+                    1,
+                    _SECRET_SHARED_SIGNED_INTEGER_MODULUS - 1,
+                    (
+                        _random_bytes(64, seed, i.to_bytes(64, 'little'))
+                        if seed is not None else
+                        None
+                    )
+                )
+                for i in range(len(secret_key['cluster']['nodes']))
+            ]
+
         return secret_key
 
     def dump(self: SecretKey) -> dict:
@@ -507,7 +581,7 @@ def load(dictionary: PublicKey) -> dict:
 def encrypt(
         key: Union[SecretKey, PublicKey],
         plaintext: Union[int, str, bytes]
-    ) -> Union[str, Sequence[str], Sequence[int]]:
+    ) -> Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]]:
     """
     Return the ciphertext obtained by using the supplied key to encrypt the
     supplied plaintext.
@@ -526,7 +600,7 @@ def encrypt(
         ):
             raise ValueError('numeric plaintext must be a valid 32-bit signed integer')
         buffer = _encode(plaintext)
-    elif 'sum' in key['operations']:
+    elif ('sum' in key['operations'] or 'redundancy' in key['operations']):
         # Non-integer cannot be encrypted for summation.
         raise ValueError('numeric plaintext must be a valid 32-bit signed integer')
 
@@ -614,6 +688,21 @@ def encrypt(
         )
         return shares
 
+    if key['operations'].get('redundancy'):
+        if len(key['cluster']['nodes']) == 1:
+            raise RuntimeError('redundancy is not supported for single-node clusters')
+
+        # Use Shamir's secret sharing for multiple-node clusters.
+        masks = [
+            key['material'][i] if 'material' in key else 1
+            for i in range(len(key['cluster']['nodes']))
+        ]
+        num_nodes = len(key['cluster']['nodes'])
+        shares = _shamirs_shares(plaintext, num_nodes)
+        for (i, share) in enumerate(shares):
+            share[1] = (masks[i] * share[1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS
+
+        return shares
     # The below should not occur unless the key's cluster or operations
     # information is malformed/missing or the plaintext is unsupported.
     raise ValueError(
@@ -622,7 +711,7 @@ def encrypt(
 
 def decrypt(
         key: SecretKey,
-        ciphertext: Union[str, Sequence[str], Sequence[int]]
+        ciphertext: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]]
     ) -> Union[int, str, bytes]:
     """
     Return the plaintext obtained by using the supplied key to decrypt the
@@ -649,6 +738,12 @@ def decrypt(
     >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True})
     >>> decrypt(key, encrypt(key, -10))
     -10
+    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'redundancy': True})
+    >>> decrypt(key, encrypt(key, 123))
+    123
+    >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'redundancy': True})
+    >>> decrypt(key, encrypt(key, -10))
+    -10
 
     An exception is raised if a ciphertext cannot be decrypted using the
     supplied key (*e.g.*, because one or both are malformed or they are
@@ -682,6 +777,14 @@ def decrypt(
         if (
             (not isinstance(ciphertext, Sequence)) or
             (not (
+                all(
+                    (
+                        isinstance(c, Sequence) and
+                        len(c) == 2 and
+                        all(isinstance(x, int) for x in c)
+                    )
+                    for c in ciphertext
+                ) or
                 all(isinstance(c, int) for c in ciphertext) or
                 all(isinstance(c, str) for c in ciphertext)
             ))
@@ -690,7 +793,10 @@ def decrypt(
                 'secret key requires a valid ciphertext from a multiple-node cluster'
             )
 
-        if len(key['cluster']['nodes']) != len(ciphertext):
+        if (
+            isinstance(ciphertext, Sequence) and
+            len(key['cluster']['nodes']) != len(ciphertext)
+        ) and not key['operations'].get('redundancy'):
             raise ValueError(
                 'secret key and ciphertext must have the same associated cluster size'
             )
@@ -760,6 +866,38 @@ def decrypt(
 
         return plaintext
 
+    # Decrypt a value that was encrypted in a redundancy-compatible way.
+    if key['operations'].get('redundancy'):
+        if len(key['cluster']['nodes']) == 1:
+            raise RuntimeError('redundancy is not supported for single-node clusters')
+
+        # For multiple-node clusters, Shamir's secret sharing is used.
+        inverse_masks = [
+            pow(
+                key['material'][i] if 'material' in key else 1,
+                _SECRET_SHARED_SIGNED_INTEGER_MODULUS - 2,
+                _SECRET_SHARED_SIGNED_INTEGER_MODULUS
+            )
+            for i in range(len(key['cluster']['nodes']))
+        ]
+        if not all(1 <= x <= len(inverse_masks) for (x, _) in ciphertext):
+            raise ValueError('share indices must identify nodes in the cluster')
+
+        # Unmask into new share points (rather than in place) so that the
+        # caller's ciphertext is left unmodified and can be decrypted again.
+        shares = [
+            [x, (inverse_masks[x - 1] * y) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS]
+            for (x, y) in ciphertext
+        ]
+        plaintext = _shamirs_recover(shares)
+
+        # Field elements in the "upper half" of the field represent negative
+        # integers.
+        if plaintext > _PLAINTEXT_SIGNED_INTEGER_MAX:
+            plaintext -= _SECRET_SHARED_SIGNED_INTEGER_MODULUS
+
+        return plaintext
+
     raise error
 
 def allot(
diff --git a/test/test_nilql.py b/test/test_nilql.py
index 6bdc5f6..08d29d6 100644
--- a/test/test_nilql.py
+++ b/test/test_nilql.py
@@ -12,6 +12,19 @@
 
 import nilql
 
+_SECRET_SHARED_SIGNED_INTEGER_MODULUS = (2 ** 32) + 15
+
+
+def _shamirs_add(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
+    """
+    Adds two sets of shares pointwise, assuming they use the same x-values.
+    """
+    return [
+        [x1, (y1 + y2) % prime]
+        for (x1, y1), (x2, y2) in zip(shares1, shares2)
+        if x1 == x2
+    ]
+
 def to_hash_base64(output: Union[bytes, list[int]]) -> str:
     """
     Helper function for converting a large output from a test into a
@@ -126,6 +139,21 @@ def test_key_operations_for_sum_with_multiple_nodes(self):
         )
         self.assertEqual(sk_from_json, sk)
 
+    def test_key_operations_for_redundancy_with_multiple_nodes(self):
+        """
+        Test key generate, dump, JSONify, and load for redundancy operation
+        with multiple nodes.
+        """
+        sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True})
+        sk_loaded = nilql.SecretKey.load(sk.dump())
+        self.assertTrue(isinstance(sk, nilql.SecretKey))
+        self.assertEqual(sk_loaded, sk)
+
+        sk_from_json = nilql.SecretKey.load(
+            json.loads(json.dumps(sk.dump()))
+        )
+        self.assertEqual(sk_from_json, sk)
+
     def test_key_from_seed_for_store_with_single_node(self):
         """
         Test key generation from seed for store operation with a single node.
@@ -201,6 +229,21 @@ def test_key_from_seed_for_sum_with_multiple_nodes(self):
             'L8RiHNq2EUgt/fDOoUw9QK2NISeUkAkhxHHIPoHPZ84='
         )
 
+    def test_key_from_seed_for_redundancy_with_multiple_nodes(self):
+        """
+        Test key generation from seed for redundancy operation with multiple nodes.
+        """
+        sk_from_seed = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}, SEED)
+        self.assertEqual(
+            to_hash_base64(sk_from_seed['material']),
+            'L8RiHNq2EUgt/fDOoUw9QK2NISeUkAkhxHHIPoHPZ84='
+        )
+        sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True})
+        self.assertNotEqual(
+            to_hash_base64(sk['material']),
+            'L8RiHNq2EUgt/fDOoUw9QK2NISeUkAkhxHHIPoHPZ84='
+        )
+
 class TestKeysError(TestCase):
     """
     Tests of errors thrown by methods of cryptographic key classes.
@@ -299,6 +342,28 @@ def test_encrypt_decrypt_of_int_for_sum_multiple(self):
         decrypted = nilql.decrypt(sk, ciphertext)
         self.assertEqual(decrypted, plaintext)
 
+    def test_encrypt_decrypt_of_int_for_redundancy_multiple(self):
+        """
+        Test encryption and decryption for redundancy operation with multiple nodes.
+        """
+        sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True})
+        plaintext = 123
+        ciphertext = nilql.encrypt(sk, plaintext)
+        decrypted = nilql.decrypt(sk, ciphertext)
+        self.assertEqual(decrypted, plaintext)
+        # Decrypting must not modify the ciphertext (it can be decrypted again).
+        self.assertEqual(nilql.decrypt(sk, ciphertext), plaintext)
+
+    def test_encrypt_decrypt_of_int_for_redundancy_with_one_failure_multiple(self):
+        """
+        Test decryption for redundancy operation when the share of one node is missing.
+        """
+        sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True})
+        plaintext = 123
+        ciphertext = nilql.encrypt(sk, plaintext)
+        decrypted = nilql.decrypt(sk, ciphertext[1:])
+        self.assertEqual(decrypted, plaintext)
+
 class TestCiphertextRepresentations(TestCase):
     """
     Tests of the portable representation of ciphertexts.
@@ -327,6 +392,18 @@ def test_ciphertext_representation_for_sum_with_multiple_nodes(self):
         decrypted = nilql.decrypt(ck, ciphertext)
         self.assertEqual(decrypted, plaintext)
 
+    def test_ciphertext_representation_for_redundancy_with_multiple_nodes(self):
+        """
+        Test the ciphertext representation for redundancy in a multiple-node cluster.
+        """
+        cluster = {'nodes': [{}, {}, {}]}
+        operations = {'redundancy': True}
+        ck = nilql.ClusterKey.generate(cluster, operations)
+        plaintext = 123
+        ciphertext = [[1, 1382717699], [2, 2765435275], [3, 4148152851]]
+        decrypted = nilql.decrypt(ck, ciphertext)
+        self.assertEqual(decrypted, plaintext)
+
 class TestFunctionsErrors(TestCase):
     """
     Tests verifying that encryption/decryption methods return expected errors.
@@ -469,3 +546,18 @@ def test_workflow_for_secure_sum_with_multiple_nodes(self):
         )
         decrypted = nilql.decrypt(sk, [a3, b3, c3])
         self.assertEqual(decrypted, 123 + 456 + 789)
+
+    def test_workflow_for_secure_redundancy_sum_with_multiple_nodes(self):
+        """
+        Test secure summation workflow over Shamir shares for a cluster that has multiple nodes.
+        """
+        sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True})
+        (a0, b0, c0) = nilql.encrypt(sk, 123)
+        (a1, b1, c1) = nilql.encrypt(sk, 456)
+        (a2, b2, c2) = nilql.encrypt(sk, 789)
+        (a3, b3, c3) = _shamirs_add(
+            _shamirs_add([a0, b0, c0], [a1, b1, c1]),
+            [a2, b2, c2]
+        )
+        decrypted = nilql.decrypt(sk, [a3, b3, c3])
+        self.assertEqual(decrypted, 123 + 456 + 789)