From eb17eb38964ae02fe30e95e81cd15083b80b660b Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Mon, 10 Feb 2025 16:02:11 -0500 Subject: [PATCH] feat: replace _seeds with HKDF. --- src/nilql/nilql.py | 50 +++++++++++++++++++++++++++++++--------------- test/test_nilql.py | 20 +++++++++---------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/nilql/nilql.py b/src/nilql/nilql.py index b36bccc..653cc82 100644 --- a/src/nilql/nilql.py +++ b/src/nilql/nilql.py @@ -8,6 +8,7 @@ import base64 import secrets import hashlib +import hmac import bcl import pailliers @@ -23,26 +24,46 @@ _PLAINTEXT_STRING_BUFFER_LEN_MAX = 4096 """Maximum length of plaintext string values that can be encrypted.""" -def _seeds(seed: bytes, index: int) -> bytes: +_Hash = hashlib.sha512 +"""Hash function used for HKDF and matching.""" + + +def _hkdf_extract(salt: bytes, input_key: bytes) -> bytes: """ - Generate entries in an indexed sequence of seeds derived from a base seed. + Extracts a pseudorandom key (PRK) using HMAC with the given salt and input key material. + If the salt is empty, a zero-filled byte string of the same length as the hash function's digest size is used. """ - if index < 0 or index >= 2 ** 64: - raise ValueError('index must be a 64-bit unsigned integer value') + if len(salt) == 0: + salt = bytes([0] * _Hash().digest_size) + return hmac.new(salt, input_key, _Hash).digest() - return hashlib.sha512(seed + index.to_bytes(8, 'little')).digest() +def _hkdf_expand(pseudo_random_key: bytes, info: bytes, length: int) -> bytes: + """ + Expands the pseudo_random_key into an output key material (OKM) of the desired length using HMAC-based expansion. + """ + t = b"" + okm = b"" + i = 0 + while len(okm) < length: + i += 1 + t = hmac.new(pseudo_random_key, t + info + bytes([i]), _Hash).digest() + okm += t + return okm[:length] + +def _hkdf(length: int, input_key: bytes, salt: bytes = b"", info: bytes = b"") -> bytes: + """ + Extract a pseudorandom key of `length` from `input_key` and optionally `salt` and `info`. + """ + prk = _hkdf_extract(salt, input_key) + return _hkdf_expand(prk, info, length) -def _random_bytes(length: int, seed: Optional[bytes] = None) -> bytes: +def _random_bytes(length: int, seed: Optional[bytes] = None, salt: Optional[bytes] = None) -> bytes: """ Return a random :obj:`bytes` value of the specified length (using the seed if one is supplied). """ if seed is not None: - bytes_ = bytes() - iterations = (length // 64) + (1 if length % 64 > 0 else 0) - for i in range(iterations): - bytes_ = bytes_ + _seeds(seed, i) - return bytes_[:length] + return _hkdf(length, seed, b"" if salt is None else salt) return secrets.token_bytes(length) @@ -72,10 +93,7 @@ def _random_int( integer = None index = 0 while integer is None or integer > range_: - bytes_ = bytearray(_random_bytes( - 8, - None if seed is None else _seeds(seed, index) - )) + bytes_ = bytearray(_random_bytes(8, seed, index.to_bytes(8, 'little'))) index += 1 bytes_[4] &= 1 bytes_[5] &= 0 @@ -514,7 +532,7 @@ def encrypt( # Encrypt (i.e., hash) a value for matching. if key['operations'].get('match'): - ciphertext = _pack(hashlib.sha512(key['material'] + buffer).digest()) + ciphertext = _pack(_Hash(key['material'] + buffer).digest()) # If there are multiple nodes, prepare the same ciphertext for each. if len(key['cluster']['nodes']) > 1: diff --git a/test/test_nilql.py b/test/test_nilql.py index 9315d4c..cd4d4d9 100644 --- a/test/test_nilql.py +++ b/test/test_nilql.py @@ -129,12 +129,12 @@ def test_key_from_seed_for_store_with_single_node(self): sk_from_seed = nilql.SecretKey.generate({'nodes': [{}]}, {'store': True}, SEED) self.assertEqual( to_hash_base64(sk_from_seed['material']), - 'TVFhJJ32+eh+yaYL1Dhcw7Z+ykY4N1cKDJXDxdS92vI=' + '2bW6BLeeCTqsCqrijSkBBPGjDb/gzjtGnFZt0nsZP8w=' ) sk = nilql.SecretKey.generate({'nodes': [{}]}, {'store': True}) self.assertNotEqual( to_hash_base64(sk['material']), - 'TVFhJJ32+eh+yaYL1Dhcw7Z+ykY4N1cKDJXDxdS92vI=' + '2bW6BLeeCTqsCqrijSkBBPGjDb/gzjtGnFZt0nsZP8w=' ) def test_key_from_seed_for_store_with_multiple_nodes(self): @@ -144,12 +144,12 @@ def test_key_from_seed_for_store_with_multiple_nodes(self): sk_from_seed = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'store': True}, SEED) self.assertEqual( to_hash_base64(sk_from_seed['material']), - 'i4ZP5syVY2V6ZFboTey/S83j+7ufgrs4/kUB849/uAI=' + 'UEoI836rNUBdCixoavnwlPEVqAe2wrPxj+UkVpJPPo0=' ) sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'store': True}) self.assertNotEqual( to_hash_base64(sk['material']), - 'i4ZP5syVY2V6ZFboTey/S83j+7ufgrs4/kUB849/uAI=' + 'UEoI836rNUBdCixoavnwlPEVqAe2wrPxj+UkVpJPPo0=' ) def test_key_from_seed_for_match_with_single_node(self): @@ -159,12 +159,12 @@ def test_key_from_seed_for_match_with_single_node(self): sk_from_seed = nilql.SecretKey.generate({'nodes': [{}]}, {'match': True}, SEED) self.assertEqual( to_hash_base64(sk_from_seed['material']), - 'M4qqWosTwaBvPMEvUDWKg/RJA3+18+mv/X5Zlj21NhY=' + 'qbcFGTOGTPo+vs3EChnVUWk5lnn6L6Cr/DIq8li4H+4=' ) sk = nilql.SecretKey.generate({'nodes': [{}]}, {'match': True}) self.assertNotEqual( to_hash_base64(sk['material']), - 'M4qqWosTwaBvPMEvUDWKg/RJA3+18+mv/X5Zlj21NhY=' + 'qbcFGTOGTPo+vs3EChnVUWk5lnn6L6Cr/DIq8li4H+4=' ) def test_key_from_seed_for_match_with_multiple_nodes(self): @@ -174,12 +174,12 @@ def test_key_from_seed_for_match_with_multiple_nodes(self): sk_from_seed = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'match': True}, SEED) self.assertEqual( to_hash_base64(sk_from_seed['material']), - 'M4qqWosTwaBvPMEvUDWKg/RJA3+18+mv/X5Zlj21NhY=' + 'qbcFGTOGTPo+vs3EChnVUWk5lnn6L6Cr/DIq8li4H+4=' ) sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'match': True}) self.assertNotEqual( to_hash_base64(sk['material']), - 'M4qqWosTwaBvPMEvUDWKg/RJA3+18+mv/X5Zlj21NhY=' + 'qbcFGTOGTPo+vs3EChnVUWk5lnn6L6Cr/DIq8li4H+4=' ) def test_key_from_seed_for_sum_with_multiple_nodes(self): @@ -189,12 +189,12 @@ def test_key_from_seed_for_sum_with_multiple_nodes(self): sk_from_seed = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'sum': True}, SEED) self.assertEqual( to_hash_base64(sk_from_seed['material']), - 'voydliW+MzaYaaIs6ydwLyZdNyYclj+APB2BxNK+AKY=' + 'l3O25x9CYiiA+XXTNPoT4WylTOXjeWj4GmoSoOPpZHo=' ) sk = nilql.SecretKey.generate({'nodes': [{}, {}, {}]}, {'sum': True}) self.assertNotEqual( to_hash_base64(sk['material']), - 'voydliW+MzaYaaIs6ydwLyZdNyYclj+APB2BxNK+AKY=' + 'l3O25x9CYiiA+XXTNPoT4WylTOXjeWj4GmoSoOPpZHo=' ) class TestKeysError(TestCase):