From 1d7cd5e7ada0a32d72032830850fba6b1ebd9832 Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Tue, 11 Feb 2025 11:49:56 -0500 Subject: [PATCH 1/3] feat: Add Shamir secret sharing --- README.rst | 3 + src/nilql/nilql.py | 180 +++++++++++++++++++++++++++++++++++++++++++-- test/test_nilql.py | 79 ++++++++++++++++++++ 3 files changed, 254 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index c5d2a83..5165e63 100644 --- a/README.rst +++ b/README.rst @@ -46,6 +46,9 @@ This library provides cryptographic operations that are compatible with nilDB no | +------------+------------------------------------+------------------------------------+ | | sum | | additive secret sharing | 32-bit signed integer | | | | | (prime modulus 2^32 + 15) | | +| +------------+------------------------------------+------------------------------------+ +| | redundancy | | Shamir secret sharing | 32-bit signed integer | +| | | | (prime modulus 2^32 + 15) | | +-------------+------------+------------------------------------+------------------------------------+ Installation and Usage diff --git a/src/nilql/nilql.py b/src/nilql/nilql.py index fa9f8a9..006e0d6 100644 --- a/src/nilql/nilql.py +++ b/src/nilql/nilql.py @@ -3,7 +3,7 @@ replies. """ from __future__ import annotations -from typing import Union, Optional, Sequence +from typing import Union, Optional, Sequence, Tuple import doctest import base64 import secrets @@ -27,6 +27,10 @@ _HASH = hashlib.sha512 """Hash function used for HKDF and matching.""" +_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2 +"""Minimum number of shares required to reconstruct a Shamir secret.""" + + def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes: """ Extracts a pseudorandom key (PRK) using HMAC with the given salt and input key material. @@ -108,6 +112,80 @@ def _random_int( return minimum + secrets.randbelow(maximum + 1 - minimum) +def _eval_at(poly, x, prime): + """Evaluates polynomial (coefficient tuple) at x.""" + accum = 0 + for coeff in reversed(poly): + accum *= x + accum += coeff + accum %= prime + return accum + +def _shamir_secret_share( + secret, + total_shares, + minimum_shares=_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION, + prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS +): + """Generates a random shamir pool for a given secret, returns share points.""" + if minimum_shares > total_shares: + raise ValueError("Pool secret would be irrecoverable.") + poly = [secret] + [secrets.randbelow(prime - 1) for _ in range(minimum_shares - 1)] + points = [(i, _eval_at(poly, i, prime)) for i in range(1, total_shares + 1)] + return points + +def _extended_gcd(a, b): + """Extended Euclidean algorithm for modular inverse.""" + x, last_x = 0, 1 + y, last_y = 1, 0 + while b != 0: + quot = a // b + a, b = b, a % b + x, last_x = last_x - quot * x, x + y, last_y = last_y - quot * y, y + return last_x, last_y + +def _divmod(num, den, p): + """Compute num / den modulo prime p.""" + inv, _ = _extended_gcd(den, p) + return num * inv % p + +def _lagrange_interpolate(x, x_s, y_s, p): + """Find the y-value for the given x using Lagrange interpolation.""" + k = len(x_s) + assert k == len(set(x_s)), "Points must be distinct" + + def _multiply(vals): + accum = 1 + for v in vals: + accum *= v + return accum + + nums, dens = [], [] + for i in range(k): + others = list(x_s) + cur = others.pop(i) + nums.append(_multiply(x - o for o in others)) + dens.append(_multiply(cur - o for o in others)) + + den = _multiply(dens) + num = sum(_divmod(nums[i] * den * y_s[i] % p, dens[i], p) for i in range(k)) + return (_divmod(num, den, p) + p) % p + +def _recover_shamir_secret(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): + """Recover the secret from share points.""" + if len(shares) < _SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION: + raise ValueError(f"Need at least {_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares") + x_s, y_s = zip(*shares) + return _lagrange_interpolate(0, x_s, y_s, prime) + +def add_shamir_shares(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): + """Adds two sets of shares pointwise, assuming they use the same x-values.""" + if len(shares1) != len(shares2): + raise ValueError("Shares sets must have the same length") + added_shares = [(x1, (y1 + y2) % prime) for (x1, y1), (x2, y2) in zip(shares1, shares2) if x1 == x2] + return added_shares + def _pack(b: bytes) -> str: """ Encode a bytes-like object as a Base64 string (for compatibility with JSON). @@ -234,7 +312,7 @@ def generate( if ( (not isinstance(operations, dict)) or - (not set(operations.keys()).issubset({'store', 'match', 'sum'})) + (not set(operations.keys()).issubset({'store', 'match', 'sum', 'redundancy'})) ): raise ValueError('valid operations specification is required') @@ -261,7 +339,7 @@ def generate( 'seed-based derivation of summation-compatible keys ' + 'is not supported for single-node clusters' ) - secret_key['material'] = pailliers.secret(2048) + secret_key['material'] = pailliers.secret(256) else: # Distinct multiplicative mask for each additive share. secret_key['material'] = [ @@ -277,6 +355,24 @@ def generate( for i in range(len(secret_key['cluster']['nodes'])) ] + if secret_key['operations'].get('redundancy'): + if len(secret_key['cluster']['nodes']) == 1: + raise RuntimeError( + 'Redundancy is not supported for single-node clusters' + ) + secret_key['material'] = [ + _random_int( + 1, + _SECRET_SHARED_SIGNED_INTEGER_MODULUS - 1, + ( + _random_bytes(64, seed, i.to_bytes(64, 'little')) + if seed is not None else + None + ) + ) + for i in range(len(secret_key['cluster']['nodes'])) + ] + return secret_key def dump(self: SecretKey) -> dict: @@ -390,6 +486,15 @@ def generate( # pylint: disable=arguments-differ # Seeds not supported. # Cluster keys contain no cryptographic material. if 'material' in cluster_key: del cluster_key['material'] +# ======= +# # Ensure that the secret key material is the identity value +# # for the supported operation. +# if len(cluster_key['cluster']['nodes']) > 1: +# if cluster_key['operations'].get('store'): +# cluster_key['material'] = bytes(_PLAINTEXT_STRING_BUFFER_LEN_MAX) +# if cluster_key['operations'].get('sum') or cluster_key['operations'].get('redundancy'): +# cluster_key['material'] = 1 +# >>>>>>> d7c678d (feat: Add Shamir secret sharing) return cluster_key @@ -507,7 +612,7 @@ def load(dictionary: PublicKey) -> dict: def encrypt( key: Union[SecretKey, PublicKey], plaintext: Union[int, str, bytes] - ) -> Union[str, Sequence[str], Sequence[int]]: + ) -> Union[str, Sequence[str], Sequence[int], Sequence[Tuple[int, int]]]: """ Return the ciphertext obtained by using the supplied key to encrypt the supplied plaintext. @@ -526,8 +631,8 @@ def encrypt( ): raise ValueError('numeric plaintext must be a valid 32-bit signed integer') buffer = _encode(plaintext) - elif 'sum' in key['operations']: - # Non-integer cannot be encrypted for summation. + elif ('sum' in key['operations'] or + 'redundancy' in key['operations']): # Non-integer cannot be encrypted for summation. raise ValueError('numeric plaintext must be a valid 32-bit signed integer') # Encode string or binary data for storage or matching. @@ -614,6 +719,24 @@ def encrypt( ) return shares + if key['operations'].get('redundancy'): + quantity = len(key['cluster']['nodes']) + if quantity == 1: + raise RuntimeError( + 'Redundancy is not supported for single-node clusters' + ) + # Use Shamir secret sharing for multiple-node clusters. + shamir_shares = _shamir_secret_share(plaintext, quantity) + masks = [ + key['material'][i] if 'material' in key else 1 + for i in range(quantity) + ] + shares = [ + (x, int((y * mask) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS)) + for ((x, y), mask) in zip(shamir_shares, masks) + ] + return shares + # The below should not occur unless the key's cluster or operations # information is malformed/missing or the plaintext is unsupported. raise ValueError( @@ -622,7 +745,7 @@ def encrypt( def decrypt( key: SecretKey, - ciphertext: Union[str, Sequence[str], Sequence[int]] + ciphertext: Union[str, Sequence[str], Sequence[int], Sequence[Tuple[int, int]]] ) -> Union[int, str, bytes]: """ Return the plaintext obtained by using the supplied key to decrypt the @@ -649,6 +772,12 @@ def decrypt( >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True}) >>> decrypt(key, encrypt(key, -10)) -10 + >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'redundancy': True}) + >>> decrypt(key, encrypt(key, 123)) + 123 + >>> key = SecretKey.generate({'nodes': [{}, {}]}, {'redundancy': True}) + >>> decrypt(key, encrypt(key, -10)) + -10 An exception is raised if a ciphertext cannot be decrypted using the supplied key (*e.g.*, because one or both are malformed or they are @@ -682,6 +811,11 @@ def decrypt( if ( (not isinstance(ciphertext, Sequence)) or (not ( + all( + isinstance(c, tuple) and len(c) == 2 and + all(isinstance(x, int) for x in c) + for c in ciphertext + ) or all(isinstance(c, int) for c in ciphertext) or all(isinstance(c, str) for c in ciphertext) )) @@ -690,7 +824,10 @@ def decrypt( 'secret key requires a valid ciphertext from a multiple-node cluster' ) - if len(key['cluster']['nodes']) != len(ciphertext): + if ( + isinstance(ciphertext, Sequence) and + len(key['cluster']['nodes']) != len(ciphertext) + ) and not key['operations'].get('redundancy'): raise ValueError( 'secret key and ciphertext must have the same associated cluster size' ) @@ -760,6 +897,33 @@ def decrypt( return plaintext + if key['operations'].get('redundancy'): + if len(key['cluster']['nodes']) == 1: + raise RuntimeError( + 'Redundancy is not supported for single-node clusters' + ) + + inverse_masks = [ + pow( + key['material'][i] if 'material' in key else 1, + _SECRET_SHARED_SIGNED_INTEGER_MODULUS - 2, + _SECRET_SHARED_SIGNED_INTEGER_MODULUS + ) + for i in range(len(key['cluster']['nodes'])) + ] + unmasked_shares = [ + (x, (y * inverse_masks[x - 1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS) + for (x, y) in ciphertext + ] + + # Use Shamir secret sharing for multiple-node clusters. + plaintext = _recover_shamir_secret(unmasked_shares) + + if plaintext > _PLAINTEXT_SIGNED_INTEGER_MAX: + plaintext -= _SECRET_SHARED_SIGNED_INTEGER_MODULUS + + return plaintext + raise error def allot( diff --git a/test/test_nilql.py b/test/test_nilql.py index 6bdc5f6..1cac9c0 100644 --- a/test/test_nilql.py +++ b/test/test_nilql.py @@ -12,6 +12,8 @@ import nilql +from src.nilql.nilql import add_shamir_shares + def to_hash_base64(output: Union[bytes, list[int]]) -> str: """ Helper function for converting a large output from a test into a @@ -126,6 +128,21 @@ def test_key_operations_for_sum_with_multiple_nodes(self): ) self.assertEqual(sk_from_json, sk) + def test_key_operations_for_redundancy_with_multiple_nodes(self): + """ + Test key generate, dump, JSONify, and load for redundancy operation + with multiple nodes. + """ + sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}) + sk_loaded = nilql.SecretKey.load(sk.dump()) + self.assertTrue(isinstance(sk, nilql.SecretKey)) + self.assertTrue(sk == sk_loaded) + + sk_from_json = nilql.SecretKey.load( + json.loads(json.dumps(sk.dump())) + ) + self.assertTrue(sk == sk_from_json) + def test_key_from_seed_for_store_with_single_node(self): """ Test key generation from seed for store operation with a single node. @@ -201,6 +218,21 @@ def test_key_from_seed_for_sum_with_multiple_nodes(self): 'L8RiHNq2EUgt/fDOoUw9QK2NISeUkAkhxHHIPoHPZ84=' ) + def test_key_from_seed_for_redundancy_with_multiple_nodes(self): + """ + Test key generation from seed for redundancy operation with multiple nodes. + """ + sk_from_seed = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}, SEED) + self.assertEqual( + to_hash_base64(sk_from_seed['material']), + 'L8RiHNq2EUgt/fDOoUw9QK2NISeUkAkhxHHIPoHPZ84=' + ) + sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}) + self.assertNotEqual( + to_hash_base64(sk['material']), + 'L8RiHNq2EUgt/fDOoUw9QK2NISeUkAkhxHHIPoHPZ84=' + ) + class TestKeysError(TestCase): """ Tests of errors thrown by methods of cryptographic key classes. @@ -299,6 +331,26 @@ def test_encrypt_decrypt_of_int_for_sum_multiple(self): decrypted = nilql.decrypt(sk, ciphertext) self.assertEqual(decrypted, plaintext) + def test_encrypt_decrypt_of_int_for_redundancy_multiple(self): + """ + Test encryption and decryption for redundancy operation with multiple nodes. + """ + sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}) + plaintext = 123 + ciphertext = nilql.encrypt(sk, plaintext) + decrypted = nilql.decrypt(sk, ciphertext) + self.assertEqual(decrypted, plaintext) + + def test_encrypt_decrypt_of_int_for_redundancy_with_one_failure_multiple(self): + """ + Test encryption and decryption for redundancy operation with multiple nodes. + """ + sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}) + plaintext = 123 + ciphertext = nilql.encrypt(sk, plaintext) + decrypted = nilql.decrypt(sk, ciphertext[1:]) + self.assertEqual(decrypted, plaintext) + class TestCiphertextRepresentations(TestCase): """ Tests of the portable representation of ciphertexts. @@ -327,6 +379,18 @@ def test_ciphertext_representation_for_sum_with_multiple_nodes(self): decrypted = nilql.decrypt(ck, ciphertext) self.assertEqual(decrypted, plaintext) + def test_ciphertext_representation_for_redundancy_with_multiple_nodes(self): + """ + Test that ciphertext representation when storing in a multiple-node cluster. + """ + cluster = {'nodes': [{}, {}, {}]} + operations = {'redundancy': True} + ck = nilql.ClusterKey.generate(cluster, operations) + plaintext = 123 + ciphertext = [(1, 1382717699), (2, 2765435275), (3, 4148152851)] + decrypted = nilql.decrypt(ck, ciphertext) + self.assertEqual(decrypted, plaintext) + class TestFunctionsErrors(TestCase): """ Tests verifying that encryption/decryption methods return expected errors. @@ -469,3 +533,18 @@ def test_workflow_for_secure_sum_with_multiple_nodes(self): ) decrypted = nilql.decrypt(sk, [a3, b3, c3]) self.assertEqual(decrypted, 123 + 456 + 789) + + def test_workflow_for_secure_redundancy_sum_with_multiple_nodes(self): + """ + Test secure summation workflow for a cluster that has multiple nodes. + """ + sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}) + (a0, b0, c0) = nilql.encrypt(sk, 123) + (a1, b1, c1) = nilql.encrypt(sk, 456) + (a2, b2, c2) = nilql.encrypt(sk, 789) + (a3, b3, c3) = add_shamir_shares( + add_shamir_shares([a0, b0, c0], [a1, b1, c1]), + [a2, b2, c2] + ) + decrypted = nilql.decrypt(sk, [a3, b3, c3]) + self.assertEqual(decrypted, 123 + 456 + 789) From 8312af43713b1c3a996018aefcbc43f1e1fda23c Mon Sep 17 00:00:00 2001 From: Andrei Lapets Date: Sun, 16 Feb 2025 23:40:34 -0500 Subject: [PATCH 2/3] fix: adjust updated code to avoid merge conflicts with main --- src/nilql/nilql.py | 110 ++++++++++++++++++++++++++------------------- test/test_nilql.py | 8 ++-- 2 files changed, 70 insertions(+), 48 deletions(-) diff --git a/src/nilql/nilql.py b/src/nilql/nilql.py index 006e0d6..03e330d 100644 --- a/src/nilql/nilql.py +++ b/src/nilql/nilql.py @@ -3,7 +3,7 @@ replies. """ from __future__ import annotations -from typing import Union, Optional, Sequence, Tuple +from typing import Union, Optional, Sequence import doctest import base64 import secrets @@ -30,7 +30,6 @@ _SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2 """Minimum number of shares required to reconstruct a Shamir secret.""" - def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes: """ Extracts a pseudorandom key (PRK) using HMAC with the given salt and input key material. @@ -113,7 +112,9 @@ def _random_int( return minimum + secrets.randbelow(maximum + 1 - minimum) def _eval_at(poly, x, prime): - """Evaluates polynomial (coefficient tuple) at x.""" + """ + Evaluates polynomial (coefficient tuple) at x. + """ accum = 0 for coeff in reversed(poly): accum *= x @@ -127,11 +128,14 @@ def _shamir_secret_share( minimum_shares=_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS ): - """Generates a random shamir pool for a given secret, returns share points.""" + """ + Generates a random shamir pool for a given secret, returns share points. + """ if minimum_shares > total_shares: raise ValueError("Pool secret would be irrecoverable.") + poly = [secret] + [secrets.randbelow(prime - 1) for _ in range(minimum_shares - 1)] - points = [(i, _eval_at(poly, i, prime)) for i in range(1, total_shares + 1)] + points = [[i, _eval_at(poly, i, prime)] for i in range(1, total_shares + 1)] return points def _extended_gcd(a, b): @@ -146,14 +150,18 @@ def _extended_gcd(a, b): return last_x, last_y def _divmod(num, den, p): - """Compute num / den modulo prime p.""" + """ + Compute num / den modulo prime p. + """ inv, _ = _extended_gcd(den, p) return num * inv % p def _lagrange_interpolate(x, x_s, y_s, p): - """Find the y-value for the given x using Lagrange interpolation.""" + """ + Find the y-value for the given x using Lagrange interpolation. + """ k = len(x_s) - assert k == len(set(x_s)), "Points must be distinct" + assert k == len(set(x_s)), "points must be distinct" def _multiply(vals): accum = 1 @@ -173,18 +181,28 @@ def _multiply(vals): return (_divmod(num, den, p) + p) % p def _recover_shamir_secret(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): - """Recover the secret from share points.""" + """ + Recover the secret from share points. + """ if len(shares) < _SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION: - raise ValueError(f"Need at least {_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares") - x_s, y_s = zip(*shares) - return _lagrange_interpolate(0, x_s, y_s, prime) + raise ValueError( + f'need at least {_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares' + ) + + return _lagrange_interpolate(0, *zip(*shares), prime) def add_shamir_shares(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): - """Adds two sets of shares pointwise, assuming they use the same x-values.""" + """ + Adds two sets of shares pointwise, assuming they use the same x-values. + """ if len(shares1) != len(shares2): - raise ValueError("Shares sets must have the same length") - added_shares = [(x1, (y1 + y2) % prime) for (x1, y1), (x2, y2) in zip(shares1, shares2) if x1 == x2] - return added_shares + raise ValueError('shares sets must have the same length') + + return [ + [x1, (y1 + y2) % prime] + for (x1, y1), (x2, y2) in zip(shares1, shares2) + if x1 == x2 + ] def _pack(b: bytes) -> str: """ @@ -360,6 +378,7 @@ def generate( raise RuntimeError( 'Redundancy is not supported for single-node clusters' ) + # Distinct multiplicative mask for each additive share. secret_key['material'] = [ _random_int( 1, @@ -612,7 +631,7 @@ def load(dictionary: PublicKey) -> dict: def encrypt( key: Union[SecretKey, PublicKey], plaintext: Union[int, str, bytes] - ) -> Union[str, Sequence[str], Sequence[int], Sequence[Tuple[int, int]]]: + ) -> Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]]: """ Return the ciphertext obtained by using the supplied key to encrypt the supplied plaintext. @@ -631,8 +650,8 @@ def encrypt( ): raise ValueError('numeric plaintext must be a valid 32-bit signed integer') buffer = _encode(plaintext) - elif ('sum' in key['operations'] or - 'redundancy' in key['operations']): # Non-integer cannot be encrypted for summation. + elif ('sum' in key['operations'] or 'redundancy' in key['operations']): + # Non-integer cannot be encrypted for summation. raise ValueError('numeric plaintext must be a valid 32-bit signed integer') # Encode string or binary data for storage or matching. @@ -720,23 +739,20 @@ def encrypt( return shares if key['operations'].get('redundancy'): - quantity = len(key['cluster']['nodes']) - if quantity == 1: - raise RuntimeError( - 'Redundancy is not supported for single-node clusters' - ) - # Use Shamir secret sharing for multiple-node clusters. - shamir_shares = _shamir_secret_share(plaintext, quantity) + if len(key['cluster']['nodes']) == 1: + raise RuntimeError('redundancy is not supported for single-node clusters') + + # Use Shamir's secret sharing for multiple-node clusters. masks = [ key['material'][i] if 'material' in key else 1 - for i in range(quantity) - ] - shares = [ - (x, int((y * mask) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS)) - for ((x, y), mask) in zip(shamir_shares, masks) + for i in range(len(key['cluster']['nodes'])) ] - return shares + num_nodes = len(key['cluster']['nodes']) + shares = _shamir_secret_share(plaintext, num_nodes) + for (i, share) in enumerate(shares): + share[1] = (masks[i] * share[1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS + return shares # The below should not occur unless the key's cluster or operations # information is malformed/missing or the plaintext is unsupported. raise ValueError( @@ -745,7 +761,7 @@ def encrypt( def decrypt( key: SecretKey, - ciphertext: Union[str, Sequence[str], Sequence[int], Sequence[Tuple[int, int]]] + ciphertext: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]] ) -> Union[int, str, bytes]: """ Return the plaintext obtained by using the supplied key to decrypt the @@ -812,8 +828,11 @@ def decrypt( (not isinstance(ciphertext, Sequence)) or (not ( all( - isinstance(c, tuple) and len(c) == 2 and - all(isinstance(x, int) for x in c) + ( + isinstance(c, Sequence) and + len(c) == 2 and + all(isinstance(x, int) for x in c) + ) for c in ciphertext ) or all(isinstance(c, int) for c in ciphertext) or @@ -897,12 +916,12 @@ def decrypt( return plaintext + # Decrypt a value that was encrypted in a summation-compatible way. if key['operations'].get('redundancy'): if len(key['cluster']['nodes']) == 1: - raise RuntimeError( - 'Redundancy is not supported for single-node clusters' - ) + raise RuntimeError('redundancy is not supported for single-node clusters') + # For multiple-node clusters, additive secret sharing is used. inverse_masks = [ pow( key['material'][i] if 'material' in key else 1, @@ -911,14 +930,15 @@ def decrypt( ) for i in range(len(key['cluster']['nodes'])) ] - unmasked_shares = [ - (x, (y * inverse_masks[x - 1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS) - for (x, y) in ciphertext - ] - - # Use Shamir secret sharing for multiple-node clusters. - plaintext = _recover_shamir_secret(unmasked_shares) + shares = ciphertext + for (i, share) in enumerate(shares): + share[1] = ( + inverse_masks[share[0] - 1] * shares[i][1] + ) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS + plaintext = _recover_shamir_secret(shares) + # Field elements in the "upper half" of the field represent negative + # integers. if plaintext > _PLAINTEXT_SIGNED_INTEGER_MAX: plaintext -= _SECRET_SHARED_SIGNED_INTEGER_MODULUS diff --git a/test/test_nilql.py b/test/test_nilql.py index 1cac9c0..af0106e 100644 --- a/test/test_nilql.py +++ b/test/test_nilql.py @@ -14,6 +14,8 @@ from src.nilql.nilql import add_shamir_shares +_SECRET_SHARED_SIGNED_INTEGER_MODULUS = (2 ** 32) + 15 + def to_hash_base64(output: Union[bytes, list[int]]) -> str: """ Helper function for converting a large output from a test into a @@ -136,12 +138,12 @@ def test_key_operations_for_redundancy_with_multiple_nodes(self): sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'redundancy': True}) sk_loaded = nilql.SecretKey.load(sk.dump()) self.assertTrue(isinstance(sk, nilql.SecretKey)) - self.assertTrue(sk == sk_loaded) + self.assertEqual(sk_loaded, sk) sk_from_json = nilql.SecretKey.load( json.loads(json.dumps(sk.dump())) ) - self.assertTrue(sk == sk_from_json) + self.assertEqual(sk_from_json, sk) def test_key_from_seed_for_store_with_single_node(self): """ @@ -387,7 +389,7 @@ def test_ciphertext_representation_for_redundancy_with_multiple_nodes(self): operations = {'redundancy': True} ck = nilql.ClusterKey.generate(cluster, operations) plaintext = 123 - ciphertext = [(1, 1382717699), (2, 2765435275), (3, 4148152851)] + ciphertext = [[1, 1382717699], [2, 2765435275], [3, 4148152851]] decrypted = nilql.decrypt(ck, ciphertext) self.assertEqual(decrypted, plaintext) From af4495313034c4d9d9f3e3b529a4afb258a10982 Mon Sep 17 00:00:00 2001 From: Andrei Lapets Date: Sun, 16 Feb 2025 23:53:41 -0500 Subject: [PATCH 3/3] refactor: use dependencies and adjust new function names --- pyproject.toml | 1 + src/nilql/nilql.py | 69 ++++++++++------------------------------------ test/test_nilql.py | 17 +++++++++--- 3 files changed, 28 insertions(+), 59 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9abef4a..edfb462 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ license = {text = "MIT"} readme = "README.rst" requires-python = ">=3.9" dependencies = [ + "lagrange~=3.0", "bcl~=2.3", "pailliers~=0.1" ] diff --git a/src/nilql/nilql.py b/src/nilql/nilql.py index 03e330d..f810b2d 100644 --- a/src/nilql/nilql.py +++ b/src/nilql/nilql.py @@ -9,6 +9,7 @@ import secrets import hashlib import hmac +from lagrange import lagrange import bcl import pailliers @@ -27,7 +28,7 @@ _HASH = hashlib.sha512 """Hash function used for HKDF and matching.""" -_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2 +_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2 """Minimum number of shares required to reconstruct a Shamir secret.""" def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes: @@ -111,7 +112,7 @@ def _random_int( return minimum + secrets.randbelow(maximum + 1 - minimum) -def _eval_at(poly, x, prime): +def _shamirs_eval(poly, x, prime): """ Evaluates polynomial (coefficient tuple) at x. """ @@ -122,76 +123,34 @@ def _eval_at(poly, x, prime): accum %= prime return accum -def _shamir_secret_share( +def _shamirs_shares( secret, total_shares, - minimum_shares=_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION, + minimum_shares=_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS ): """ - Generates a random shamir pool for a given secret, returns share points. + Generates a random Shamir pool for a given secret and returns share points. """ if minimum_shares > total_shares: raise ValueError("Pool secret would be irrecoverable.") poly = [secret] + [secrets.randbelow(prime - 1) for _ in range(minimum_shares - 1)] - points = [[i, _eval_at(poly, i, prime)] for i in range(1, total_shares + 1)] + points = [[i, _shamirs_eval(poly, i, prime)] for i in range(1, total_shares + 1)] return points -def _extended_gcd(a, b): - """Extended Euclidean algorithm for modular inverse.""" - x, last_x = 0, 1 - y, last_y = 1, 0 - while b != 0: - quot = a // b - a, b = b, a % b - x, last_x = last_x - quot * x, x - y, last_y = last_y - quot * y, y - return last_x, last_y - -def _divmod(num, den, p): - """ - Compute num / den modulo prime p. - """ - inv, _ = _extended_gcd(den, p) - return num * inv % p - -def _lagrange_interpolate(x, x_s, y_s, p): - """ - Find the y-value for the given x using Lagrange interpolation. - """ - k = len(x_s) - assert k == len(set(x_s)), "points must be distinct" - - def _multiply(vals): - accum = 1 - for v in vals: - accum *= v - return accum - - nums, dens = [], [] - for i in range(k): - others = list(x_s) - cur = others.pop(i) - nums.append(_multiply(x - o for o in others)) - dens.append(_multiply(cur - o for o in others)) - - den = _multiply(dens) - num = sum(_divmod(nums[i] * den * y_s[i] % p, dens[i], p) for i in range(k)) - return (_divmod(num, den, p) + p) % p - -def _recover_shamir_secret(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): +def _shamirs_recover(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): """ Recover the secret from share points. """ - if len(shares) < _SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION: + if len(shares) < _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION: raise ValueError( - f'need at least {_SHAMIR_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares' + f'need at least {_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares' ) - return _lagrange_interpolate(0, *zip(*shares), prime) + return lagrange(shares, prime) -def add_shamir_shares(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): +def _shamirs_add(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): """ Adds two sets of shares pointwise, assuming they use the same x-values. """ @@ -748,7 +707,7 @@ def encrypt( for i in range(len(key['cluster']['nodes'])) ] num_nodes = len(key['cluster']['nodes']) - shares = _shamir_secret_share(plaintext, num_nodes) + shares = _shamirs_shares(plaintext, num_nodes) for (i, share) in enumerate(shares): share[1] = (masks[i] * share[1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS @@ -935,7 +894,7 @@ def decrypt( share[1] = ( inverse_masks[share[0] - 1] * shares[i][1] ) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS - plaintext = _recover_shamir_secret(shares) + plaintext = _shamirs_recover(shares) # Field elements in the "upper half" of the field represent negative # integers. diff --git a/test/test_nilql.py b/test/test_nilql.py index af0106e..08d29d6 100644 --- a/test/test_nilql.py +++ b/test/test_nilql.py @@ -12,10 +12,19 @@ import nilql -from src.nilql.nilql import add_shamir_shares - _SECRET_SHARED_SIGNED_INTEGER_MODULUS = (2 ** 32) + 15 + +def _shamirs_add(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS): + """ + Adds two sets of shares pointwise, assuming they use the same x-values. + """ + return [ + [x1, (y1 + y2) % prime] + for (x1, y1), (x2, y2) in zip(shares1, shares2) + if x1 == x2 + ] + def to_hash_base64(output: Union[bytes, list[int]]) -> str: """ Helper function for converting a large output from a test into a @@ -544,8 +553,8 @@ def test_workflow_for_secure_redundancy_sum_with_multiple_nodes(self): (a0, b0, c0) = nilql.encrypt(sk, 123) (a1, b1, c1) = nilql.encrypt(sk, 456) (a2, b2, c2) = nilql.encrypt(sk, 789) - (a3, b3, c3) = add_shamir_shares( - add_shamir_shares([a0, b0, c0], [a1, b1, c1]), + (a3, b3, c3) = _shamirs_add( + _shamirs_add([a0, b0, c0], [a1, b1, c1]), [a2, b2, c2] ) decrypted = nilql.decrypt(sk, [a3, b3, c3])