Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ This library provides cryptographic operations that are compatible with nilDB no
| +------------+------------------------------------+------------------------------------+
| | sum | | additive secret sharing | 32-bit signed integer |
| | | | (prime modulus 2^32 + 15) | |
| +------------+------------------------------------+------------------------------------+
| | redundancy | | Shamir secret sharing | 32-bit signed integer |
| | | | (prime modulus 2^32 + 15) | |
+-------------+------------+------------------------------------+------------------------------------+

Installation and Usage
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ license = {text = "MIT"}
readme = "README.rst"
requires-python = ">=3.9"
dependencies = [
"lagrange~=3.0",
"bcl~=2.3",
"pailliers~=0.1"
]
Expand Down
155 changes: 149 additions & 6 deletions src/nilql/nilql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import secrets
import hashlib
import hmac
from lagrange import lagrange
import bcl
import pailliers

Expand All @@ -27,6 +28,9 @@
_HASH = hashlib.sha512
"""Hash function used for HKDF and matching."""

_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION = 2
"""Minimum number of shares required to reconstruct a Shamir secret."""

def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes:
"""
Extracts a pseudorandom key (PRK) using HMAC with the given salt and input key material.
Expand Down Expand Up @@ -108,6 +112,57 @@ def _random_int(

return minimum + secrets.randbelow(maximum + 1 - minimum)

def _shamirs_eval(poly, x, prime):
"""
Evaluates polynomial (coefficient tuple) at x.
"""
accum = 0
for coeff in reversed(poly):
accum *= x
accum += coeff
accum %= prime
return accum

def _shamirs_shares(
secret,
total_shares,
minimum_shares=_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION,
prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS
):
"""
Generates a random Shamir pool for a given secret and returns share points.
"""
if minimum_shares > total_shares:
raise ValueError("Pool secret would be irrecoverable.")

poly = [secret] + [secrets.randbelow(prime - 1) for _ in range(minimum_shares - 1)]
points = [[i, _shamirs_eval(poly, i, prime)] for i in range(1, total_shares + 1)]
return points

def _shamirs_recover(shares, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
"""
Recover the secret from share points.
"""
if len(shares) < _SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION:
raise ValueError(
f'need at least {_SHAMIRS_MINIMUM_SHARES_FOR_RECONSTRUCTION} shares'
)

return lagrange(shares, prime)

def _shamirs_add(shares1, shares2, prime=_SECRET_SHARED_SIGNED_INTEGER_MODULUS):
"""
Adds two sets of shares pointwise, assuming they use the same x-values.
"""
if len(shares1) != len(shares2):
raise ValueError('shares sets must have the same length')

return [
[x1, (y1 + y2) % prime]
for (x1, y1), (x2, y2) in zip(shares1, shares2)
if x1 == x2
]

def _pack(b: bytes) -> str:
"""
Encode a bytes-like object as a Base64 string (for compatibility with JSON).
Expand Down Expand Up @@ -234,7 +289,7 @@ def generate(

if (
(not isinstance(operations, dict)) or
(not set(operations.keys()).issubset({'store', 'match', 'sum'}))
(not set(operations.keys()).issubset({'store', 'match', 'sum', 'redundancy'}))
):
raise ValueError('valid operations specification is required')

Expand All @@ -261,7 +316,7 @@ def generate(
'seed-based derivation of summation-compatible keys ' +
'is not supported for single-node clusters'
)
secret_key['material'] = pailliers.secret(2048)
secret_key['material'] = pailliers.secret(256)
else:
# Distinct multiplicative mask for each additive share.
secret_key['material'] = [
Expand All @@ -277,6 +332,25 @@ def generate(
for i in range(len(secret_key['cluster']['nodes']))
]

if secret_key['operations'].get('redundancy'):
if len(secret_key['cluster']['nodes']) == 1:
raise RuntimeError(
'Redundancy is not supported for single-node clusters'
)
# Distinct multiplicative mask for each additive share.
secret_key['material'] = [
_random_int(
1,
_SECRET_SHARED_SIGNED_INTEGER_MODULUS - 1,
(
_random_bytes(64, seed, i.to_bytes(64, 'little'))
if seed is not None else
None
)
)
for i in range(len(secret_key['cluster']['nodes']))
]

return secret_key

def dump(self: SecretKey) -> dict:
Expand Down Expand Up @@ -390,6 +464,15 @@ def generate( # pylint: disable=arguments-differ # Seeds not supported.
# Cluster keys contain no cryptographic material.
if 'material' in cluster_key:
del cluster_key['material']
# =======
# # Ensure that the secret key material is the identity value
# # for the supported operation.
# if len(cluster_key['cluster']['nodes']) > 1:
# if cluster_key['operations'].get('store'):
# cluster_key['material'] = bytes(_PLAINTEXT_STRING_BUFFER_LEN_MAX)
# if cluster_key['operations'].get('sum') or cluster_key['operations'].get('redundancy'):
# cluster_key['material'] = 1
# >>>>>>> d7c678d (feat: Add Shamir secret sharing)

return cluster_key

Expand Down Expand Up @@ -507,7 +590,7 @@ def load(dictionary: PublicKey) -> dict:
def encrypt(
key: Union[SecretKey, PublicKey],
plaintext: Union[int, str, bytes]
) -> Union[str, Sequence[str], Sequence[int]]:
) -> Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]]:
"""
Return the ciphertext obtained by using the supplied key to encrypt the
supplied plaintext.
Expand All @@ -526,7 +609,7 @@ def encrypt(
):
raise ValueError('numeric plaintext must be a valid 32-bit signed integer')
buffer = _encode(plaintext)
elif 'sum' in key['operations']:
elif ('sum' in key['operations'] or 'redundancy' in key['operations']):
# Non-integer cannot be encrypted for summation.
raise ValueError('numeric plaintext must be a valid 32-bit signed integer')

Expand Down Expand Up @@ -614,6 +697,21 @@ def encrypt(
)
return shares

if key['operations'].get('redundancy'):
if len(key['cluster']['nodes']) == 1:
raise RuntimeError('redundancy is not supported for single-node clusters')

# Use Shamir's secret sharing for multiple-node clusters.
masks = [
key['material'][i] if 'material' in key else 1
for i in range(len(key['cluster']['nodes']))
]
num_nodes = len(key['cluster']['nodes'])
shares = _shamirs_shares(plaintext, num_nodes)
for (i, share) in enumerate(shares):
share[1] = (masks[i] * share[1]) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS

return shares
# The below should not occur unless the key's cluster or operations
# information is malformed/missing or the plaintext is unsupported.
raise ValueError(
Expand All @@ -622,7 +720,7 @@ def encrypt(

def decrypt(
key: SecretKey,
ciphertext: Union[str, Sequence[str], Sequence[int]]
ciphertext: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]]
) -> Union[int, str, bytes]:
"""
Return the plaintext obtained by using the supplied key to decrypt the
Expand All @@ -649,6 +747,12 @@ def decrypt(
>>> key = SecretKey.generate({'nodes': [{}, {}]}, {'sum': True})
>>> decrypt(key, encrypt(key, -10))
-10
>>> key = SecretKey.generate({'nodes': [{}, {}]}, {'redundancy': True})
>>> decrypt(key, encrypt(key, 123))
123
>>> key = SecretKey.generate({'nodes': [{}, {}]}, {'redundancy': True})
>>> decrypt(key, encrypt(key, -10))
-10

An exception is raised if a ciphertext cannot be decrypted using the
supplied key (*e.g.*, because one or both are malformed or they are
Expand Down Expand Up @@ -682,6 +786,14 @@ def decrypt(
if (
(not isinstance(ciphertext, Sequence)) or
(not (
all(
(
isinstance(c, Sequence) and
len(c) == 2 and
all(isinstance(x, int) for x in c)
)
for c in ciphertext
) or
all(isinstance(c, int) for c in ciphertext) or
all(isinstance(c, str) for c in ciphertext)
))
Expand All @@ -690,7 +802,10 @@ def decrypt(
'secret key requires a valid ciphertext from a multiple-node cluster'
)

if len(key['cluster']['nodes']) != len(ciphertext):
if (
isinstance(ciphertext, Sequence) and
len(key['cluster']['nodes']) != len(ciphertext)
) and not key['operations'].get('redundancy'):
raise ValueError(
'secret key and ciphertext must have the same associated cluster size'
)
Expand Down Expand Up @@ -760,6 +875,34 @@ def decrypt(

return plaintext

# Decrypt a value that was encrypted in a summation-compatible way.
if key['operations'].get('redundancy'):
if len(key['cluster']['nodes']) == 1:
raise RuntimeError('redundancy is not supported for single-node clusters')

# For multiple-node clusters, additive secret sharing is used.
inverse_masks = [
pow(
key['material'][i] if 'material' in key else 1,
_SECRET_SHARED_SIGNED_INTEGER_MODULUS - 2,
_SECRET_SHARED_SIGNED_INTEGER_MODULUS
)
for i in range(len(key['cluster']['nodes']))
]
shares = ciphertext
for (i, share) in enumerate(shares):
share[1] = (
inverse_masks[share[0] - 1] * shares[i][1]
) % _SECRET_SHARED_SIGNED_INTEGER_MODULUS
plaintext = _shamirs_recover(shares)

# Field elements in the "upper half" of the field represent negative
# integers.
if plaintext > _PLAINTEXT_SIGNED_INTEGER_MAX:
plaintext -= _SECRET_SHARED_SIGNED_INTEGER_MODULUS

return plaintext

raise error

def allot(
Expand Down
Loading