From 2a95278a50a6d2dae2cc3298e50f1e6f5c55192b Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 01:27:15 -0700 Subject: [PATCH 01/35] Create new file called ChaCha20.py --- crypto/ChaCha20.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 crypto/ChaCha20.py diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py new file mode 100644 index 000000000..e69de29bb From 3b20043e8e30a2347a9e3f9f5dac2d58191724a9 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 01:30:20 -0700 Subject: [PATCH 02/35] Add docstring for the ChaCha20 algorithm --- crypto/ChaCha20.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index e69de29bb..c28ce4ed8 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -0,0 +1,17 @@ +from typing import List +import struct + +__all__ = ['ChaCha20'] +class ChaCha20: + """ + Implementation of the ChaCha20 stream cipher. + + Attributes + ---------- + key : bytes + 32-byte (256-bit) encryption key. + nonce : bytes + 12-byte (96-bit) nonce. + counter : int + 32-bit counter, typically starts at 0. + """ \ No newline at end of file From 4b663fbcceab6d03d0a8303942ce9dd8b57fa853 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 01:39:35 -0700 Subject: [PATCH 03/35] Implement __new__() constructor with key (32 bytes) and nonce (12 bytes) length checks --- crypto/ChaCha20.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index c28ce4ed8..30d2f235a 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -14,4 +14,14 @@ class ChaCha20: 12-byte (96-bit) nonce. counter : int 32-bit counter, typically starts at 0. - """ \ No newline at end of file + """ + def __new__(cls, key: bytes, nonce: bytes, counter: int = 0): + if not isinstance(key, bytes) or len(key) != 32: + raise ValueError("Key must be exactly 32 bytes (256 bits).") + if not isinstance(nonce, bytes) or len(nonce) != 12: + raise ValueError("Nonce must be exactly 12 bytes (96 bits).") + instance = super().__new__(cls) + instance.key = key + instance.nonce = nonce + instance.counter = counter + return instance \ No newline at end of file From c69a13a2acd50f377ae08b3dc90d411a3ce6021a Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 01:43:08 -0700 Subject: [PATCH 04/35] Add quarter-round function --- crypto/ChaCha20.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index 30d2f235a..765ae4bc4 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -24,4 +24,17 @@ def __new__(cls, key: bytes, nonce: bytes, counter: int = 0): instance.key = key instance.nonce = nonce instance.counter = counter - return instance \ No newline at end of file + return instance + def _quarter_round(self, state: List[int], a: int, b: int, c: int, d: int): + state[a] = (state[a] + state[b]) % (2**32) + state[d] ^= state[a] + state[d] = ((state[d] << 16) | (state[d] >> 16)) % (2**32) + state[c] = (state[c] + state[d]) % (2**32) + state[b] ^= state[c] + state[b] = ((state[b] << 12) | (state[b] >> 20)) % (2**32) + state[a] = (state[a] + state[b]) % (2**32) + state[d] ^= state[a] + state[d] = ((state[d] << 8) | (state[d] >> 24)) % (2**32) + state[c] = (state[c] + state[d]) % (2**32) + state[b] ^= state[c] + state[b] = ((state[b] << 7) | (state[b] >> 25)) % (2**32) From 2e0f28bbfd576f3fea711771f98dca169b33ed1d Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 01:46:03 -0700 Subject: [PATCH 05/35] Add double-round function --- crypto/ChaCha20.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index 765ae4bc4..d196af11f 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -38,3 +38,12 @@ def _quarter_round(self, state: List[int], a: int, b: int, c: int, d: int): state[c] = (state[c] + state[d]) % (2**32) state[b] ^= state[c] state[b] = ((state[b] << 7) | (state[b] >> 25)) % (2**32) + def _double_round(self, state: List[int]): + self._quarter_round(state, 0, 4, 8, 12) + self._quarter_round(state, 1, 5, 9, 13) + self._quarter_round(state, 2, 6, 10, 14) + self._quarter_round(state, 3, 7, 11, 15) + self._quarter_round(state, 0, 5, 10, 15) + self._quarter_round(state, 1, 6, 11, 12) + self._quarter_round(state, 2, 7, 8, 13) + self._quarter_round(state, 3, 4, 9, 14) \ No newline at end of file From 67e7e89a02e508c06f546f0211d87cb711438374 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 02:18:26 -0700 Subject: [PATCH 06/35] Add function describing ChaCha20 initial state --- crypto/ChaCha20.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index d196af11f..b64bbffc2 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -1,6 +1,7 @@ from typing import List import struct - +import numpy as np +from copy import deepcopy as dp __all__ = ['ChaCha20'] class ChaCha20: """ @@ -46,4 +47,23 @@ def _double_round(self, state: List[int]): self._quarter_round(state, 0, 5, 10, 15) self._quarter_round(state, 1, 6, 11, 12) self._quarter_round(state, 2, 7, 8, 13) - self._quarter_round(state, 3, 4, 9, 14) \ No newline at end of file + self._quarter_round(state, 3, 4, 9, 14) + def _chacha20_block(self, counter: int) -> bytes: + """ + Generates a 64-byte keystream block from 16-word (512-bit) state + The initial state is copied to preserve the original. + 20 rounds (10 double rounds) are performed using quarter-round operations. + The modified working state is combined with the original state using modular addition (mod 2^32). + The result is returned as a 64-byte keystream block. + """ + constants = b"expand 32-byte k" + state_values = struct.unpack( + '<16I', + constants + self.key + struct.pack(' Date: Tue, 18 Feb 2025 11:48:08 -0700 Subject: [PATCH 07/35] Implement apply_keystream() method for ChaCha20 XOR operation --- crypto/ChaCha20.py | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index b64bbffc2..e0d22b9ff 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -66,4 +66,40 @@ def _chacha20_block(self, counter: int) -> bytes: for _ in range(10): self._double_round(working_state) final_state = (working_state + state) % (2**32) - return struct.pack('<16I', *final_state.flatten()) \ No newline at end of file + return struct.pack('<16I', *final_state.flatten()) + + def _apply_keystream(self, data: bytes) -> bytes: + """ + Applies the ChaCha20 keystream to the input data (plaintext or ciphertext) + to perform encryption or decryption. + + This method processes the input data in 64-byte blocks. For each block: + - A 64-byte keystream is generated using the `_chacha20_block()` function. + - Each byte of the input block is XORed with the corresponding keystream byte. + - The XORed result is appended to the output. + + The same function is used for both encryption and decryption because + XORing the ciphertext with the same keystream returns the original plaintext. + + Args: + data (bytes): The input data to be encrypted or decrypted (plaintext or ciphertext). + + Returns: + bytes: The result of XORing the input data with the ChaCha20 keystream + (ciphertext if plaintext was provided, plaintext if ciphertext was provided). + """ + result = b"" + chunk_size = 64 + start = 0 + while start < len(data): + chunk = data[start:start + chunk_size] + start += chunk_size + keystream = self._chacha20_block(self.counter) + self.counter += 1 + xor_block = [] + for idx in range(len(chunk)): + input_byte = chunk[idx] + keystream_byte = keystream[idx] + xor_block.append(input_byte ^ keystream_byte) + result += bytes(xor_block) + return result From 95eb1e97debb63b34e45ecd2e62bdd9af0ed66c7 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 11:51:47 -0700 Subject: [PATCH 08/35] Add encrypt method using apply_keystream method --- crypto/ChaCha20.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index e0d22b9ff..1eb71679d 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -103,3 +103,19 @@ def _apply_keystream(self, data: bytes) -> bytes: xor_block.append(input_byte ^ keystream_byte) result += bytes(xor_block) return result + def encrypt(self, plaintext: bytes) -> bytes: + """ + Encrypts the given plaintext using the ChaCha20 stream cipher. + + This method uses the ChaCha20 keystream generated from the + key, nonce, and counter to XOR with the plaintext, producing ciphertext. + + Args: + plaintext (bytes): The plaintext data to be encrypted. + + Returns: + bytes: The resulting ciphertext. + """ + return self._apply_keystream(plaintext) + + \ No newline at end of file From 654e0086b0eb526510acadaf7314aa5797234ae5 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 11:53:24 -0700 Subject: [PATCH 09/35] Add decrypt method using apply_keystream method --- crypto/ChaCha20.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index 1eb71679d..41a733bde 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -118,4 +118,17 @@ def encrypt(self, plaintext: bytes) -> bytes: """ return self._apply_keystream(plaintext) - \ No newline at end of file + def decrypt(self, ciphertext: bytes) -> bytes: + """ + Decrypts the given ciphertext using the ChaCha20 stream cipher. + + Since ChaCha20 uses XOR for encryption, decryption is performed + using the same keystream and XOR operation. + + Args: + ciphertext (bytes): The ciphertext data to be decrypted. + + Returns: + bytes: The resulting plaintext. + """ + return self.apply_keystream(ciphertext) \ No newline at end of file From f85491ce5fd854e95058632567ab8bd2f9ec3b92 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 11:58:14 -0700 Subject: [PATCH 10/35] Add new file called test_chacha20.py --- crypto/tests/test_chacha20.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 crypto/tests/test_chacha20.py diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py new file mode 100644 index 000000000..e69de29bb From 62f93d91e3d2992f96bf76b212c17326df8ce8d8 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 12:56:13 -0700 Subject: [PATCH 11/35] Add explicit size assertionf for VALID_KEY and VALID_NONCE --- crypto/tests/test_chacha20.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py index e69de29bb..55f4a2fc5 100644 --- a/crypto/tests/test_chacha20.py +++ b/crypto/tests/test_chacha20.py @@ -0,0 +1,10 @@ +import random +import string +from crypto.ChaCha20 import ChaCha20 + +VALID_KEY = b"\x00" *32 +assert len(VALID_KEY) == 32, "VALID_KEY must be exactly 32 bytes" +VALID_NONCE = B"\x00" * 12 +assert len(VALID_NONCE) == 12, "VALID_NONCE must be exactly 12 bytes" + +secure_rng = random.SystemRandom() \ No newline at end of file From 3bf8840eec53999b5c8795befad33da30d12b673 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:00:05 -0700 Subject: [PATCH 12/35] Add unit test for ChaCha20 key size validation --- crypto/tests/test_chacha20.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py index 55f4a2fc5..be5608869 100644 --- a/crypto/tests/test_chacha20.py +++ b/crypto/tests/test_chacha20.py @@ -7,4 +7,20 @@ VALID_NONCE = B"\x00" * 12 assert len(VALID_NONCE) == 12, "VALID_NONCE must be exactly 12 bytes" -secure_rng = random.SystemRandom() \ No newline at end of file +secure_rng = random.SystemRandom() + +def test_invalid_key_size(): + """Test invalid key sizes.""" + try: + ChaCha20(b"short_key", VALID_NONCE) + except ValueError as e: + assert "Key must be exactly 32 bytes" in str(e) + else: + assert False, "ValueError was not raised for short key" + + try: + ChaCha20(b"A" * 33, VALID_NONCE) + except ValueError as e: + assert "Key must be exactly 32 bytes" in str(e) + else: + assert False, "ValueError was not raised for long key" From 8541ed619aa46bb49903141c3accf9c0cb74e442 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:01:54 -0700 Subject: [PATCH 13/35] Add unit test for ChaCha20 nonce size validation --- crypto/tests/test_chacha20.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py index be5608869..72bed55ef 100644 --- a/crypto/tests/test_chacha20.py +++ b/crypto/tests/test_chacha20.py @@ -24,3 +24,19 @@ def test_invalid_key_size(): assert "Key must be exactly 32 bytes" in str(e) else: assert False, "ValueError was not raised for long key" + +def test_invalid_nonce_size(): + """Test invalid nonce sizes.""" + try: + ChaCha20(VALID_KEY, b"short") + except ValueError as e: + assert "Nonce must be exactly 12 bytes" in str(e) + else: + assert False, "ValueError was not raised for short nonce" + + try: + ChaCha20(VALID_KEY, b"A" * 13) + except ValueError as e: + assert "Nonce must be exactly 12 bytes" in str(e) + else: + assert False, "ValueError was not raised for long nonce" From a2cfcbfbfe4ae8ca3055fae158602a87d2b6017f Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:05:10 -0700 Subject: [PATCH 14/35] Add unit test for negative counter validation --- crypto/tests/test_chacha20.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py index 72bed55ef..9904adc12 100644 --- a/crypto/tests/test_chacha20.py +++ b/crypto/tests/test_chacha20.py @@ -40,3 +40,13 @@ def test_invalid_nonce_size(): assert "Nonce must be exactly 12 bytes" in str(e) else: assert False, "ValueError was not raised for long nonce" + +def test_invalid_counter_values(): + """Test invalid counter values for ChaCha20.""" + for invalid_counter in [-1, -100, -999999]: + try: + ChaCha20(VALID_KEY, VALID_NONCE, counter=invalid_counter) + except ValueError as e: + assert "Counter must be a non-negative integer" in str(e) + else: + assert False, f"ValueError not raised for counter={invalid_counter}" From f261f945bd72d0c05865ea9b95a0eca930d7d90e Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:09:51 -0700 Subject: [PATCH 15/35] Add test case that verifies that ChaCha20 produces a reversible ciphertext --- crypto/tests/test_chacha20.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py index 9904adc12..098a36310 100644 --- a/crypto/tests/test_chacha20.py +++ b/crypto/tests/test_chacha20.py @@ -40,7 +40,7 @@ def test_invalid_nonce_size(): assert "Nonce must be exactly 12 bytes" in str(e) else: assert False, "ValueError was not raised for long nonce" - + def test_invalid_counter_values(): """Test invalid counter values for ChaCha20.""" for invalid_counter in [-1, -100, -999999]: @@ -50,3 +50,12 @@ def test_invalid_counter_values(): assert "Counter must be a non-negative integer" in str(e) else: assert False, f"ValueError not raised for counter={invalid_counter}" + +def test_encrypt_decrypt(): + """Test encryption and decryption are symmetric.""" + cipher = ChaCha20(VALID_KEY, VALID_NONCE) + plaintext = b"Hello, ChaCha20!" + ciphertext = cipher.encrypt(plaintext) + decrypted = cipher.decrypt(ciphertext) + + assert decrypted == plaintext, "Decryption failed. Plaintext does not match." From 9eddec4bbfd908ec623687542b91cc8da2c3960b Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:14:57 -0700 Subject: [PATCH 16/35] Handle ChaCha20 encryption and decryption for empty input --- crypto/ChaCha20.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crypto/ChaCha20.py b/crypto/ChaCha20.py index 41a733bde..aed92f444 100644 --- a/crypto/ChaCha20.py +++ b/crypto/ChaCha20.py @@ -88,6 +88,8 @@ def _apply_keystream(self, data: bytes) -> bytes: bytes: The result of XORing the input data with the ChaCha20 keystream (ciphertext if plaintext was provided, plaintext if ciphertext was provided). """ + if len(data) == 0: + return b"" result = b"" chunk_size = 64 start = 0 From c4746f40b90e421a9749da37e6b74768d143e571 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:41:23 -0700 Subject: [PATCH 17/35] Add test case ChaCha20 key reuse vulnerability --- crypto/tests/test_chacha20.py | 53 +++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py index 098a36310..0826947e9 100644 --- a/crypto/tests/test_chacha20.py +++ b/crypto/tests/test_chacha20.py @@ -59,3 +59,56 @@ def test_encrypt_decrypt(): decrypted = cipher.decrypt(ciphertext) assert decrypted == plaintext, "Decryption failed. Plaintext does not match." + +def test_key_reuse_simple(): + """ + Test the vulnerability of key reuse in ChaCha20 encryption. + + This test demonstrates the security flaw of reusing the same key and nonce + for different plaintexts in stream ciphers. It exploits the property that + XORing two ciphertexts from the same keystream cancels out the keystream, + revealing the XOR of the plaintexts. + + Encrypt two different plaintexts with the same key and nonce. + XOR the resulting ciphertexts to remove the keystream, leaving only the XOR of plaintexts. + XOR the result with the first plaintext to recover the second plaintext. + Assert that the recovered plaintext matches the original second plaintext. + + Expected Behavior: + - If the ChaCha20 implementation is correct, reusing the same key and nonce + will expose the XOR relationship between plaintexts. + - The test should successfully recover the second plaintext using XOR operations. + + Assertion: + - Raises an AssertionError if the recovered plaintext does not match the + original second plaintext, indicating a failure in the XOR recovery logic. + + Output: + - Prints the original second plaintext. + - Prints the recovered plaintext (should be identical to the original). + - Displays the XOR result (hexadecimal format) for inspection. + + Security Note: + - This test highlights why it is critical never to reuse the same key and nonce + in stream ciphers like ChaCha20. + """ + + + cipher1 = ChaCha20(VALID_KEY, VALID_NONCE) + cipher2 = ChaCha20(VALID_KEY, VALID_NONCE) + + plaintext1 = b"Hello, this is message one!" + plaintext2 = b"Hi there, this is message two!" + + ciphertext1 = cipher1.encrypt(plaintext1) + ciphertext2 = cipher2.encrypt(plaintext2) + + xor_result = [] + for c1_byte, c2_byte in zip(ciphertext1, ciphertext2): + xor_result.append(c1_byte ^ c2_byte) + xor_bytes = bytes(xor_result) + recovered = [] + for xor_byte, p1_byte in zip(xor_bytes, plaintext1): + recovered.append(xor_byte ^ p1_byte) + recovered_plaintext = bytes(recovered) + assert recovered_plaintext == plaintext2, "Failed to recover second plaintext from XOR pattern" From c84c0f568e2cd04f3f6bc20b2d71158e19ae528b Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:47:53 -0700 Subject: [PATCH 18/35] Move crypto directory under pydatastructs --- {crypto => pydatastructs/crypto}/ChaCha20.py | 0 pydatastructs/crypto/tests/test_chacha20.py | 114 +++++++++++++++++++ 2 files changed, 114 insertions(+) rename {crypto => pydatastructs/crypto}/ChaCha20.py (100%) create mode 100644 pydatastructs/crypto/tests/test_chacha20.py diff --git a/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py similarity index 100% rename from crypto/ChaCha20.py rename to pydatastructs/crypto/ChaCha20.py diff --git a/pydatastructs/crypto/tests/test_chacha20.py b/pydatastructs/crypto/tests/test_chacha20.py new file mode 100644 index 000000000..0826947e9 --- /dev/null +++ b/pydatastructs/crypto/tests/test_chacha20.py @@ -0,0 +1,114 @@ +import random +import string +from crypto.ChaCha20 import ChaCha20 + +VALID_KEY = b"\x00" *32 +assert len(VALID_KEY) == 32, "VALID_KEY must be exactly 32 bytes" +VALID_NONCE = B"\x00" * 12 +assert len(VALID_NONCE) == 12, "VALID_NONCE must be exactly 12 bytes" + +secure_rng = random.SystemRandom() + +def test_invalid_key_size(): + """Test invalid key sizes.""" + try: + ChaCha20(b"short_key", VALID_NONCE) + except ValueError as e: + assert "Key must be exactly 32 bytes" in str(e) + else: + assert False, "ValueError was not raised for short key" + + try: + ChaCha20(b"A" * 33, VALID_NONCE) + except ValueError as e: + assert "Key must be exactly 32 bytes" in str(e) + else: + assert False, "ValueError was not raised for long key" + +def test_invalid_nonce_size(): + """Test invalid nonce sizes.""" + try: + ChaCha20(VALID_KEY, b"short") + except ValueError as e: + assert "Nonce must be exactly 12 bytes" in str(e) + else: + assert False, "ValueError was not raised for short nonce" + + try: + ChaCha20(VALID_KEY, b"A" * 13) + except ValueError as e: + assert "Nonce must be exactly 12 bytes" in str(e) + else: + assert False, "ValueError was not raised for long nonce" + +def test_invalid_counter_values(): + """Test invalid counter values for ChaCha20.""" + for invalid_counter in [-1, -100, -999999]: + try: + ChaCha20(VALID_KEY, VALID_NONCE, counter=invalid_counter) + except ValueError as e: + assert "Counter must be a non-negative integer" in str(e) + else: + assert False, f"ValueError not raised for counter={invalid_counter}" + +def test_encrypt_decrypt(): + """Test encryption and decryption are symmetric.""" + cipher = ChaCha20(VALID_KEY, VALID_NONCE) + plaintext = b"Hello, ChaCha20!" + ciphertext = cipher.encrypt(plaintext) + decrypted = cipher.decrypt(ciphertext) + + assert decrypted == plaintext, "Decryption failed. Plaintext does not match." + +def test_key_reuse_simple(): + """ + Test the vulnerability of key reuse in ChaCha20 encryption. + + This test demonstrates the security flaw of reusing the same key and nonce + for different plaintexts in stream ciphers. It exploits the property that + XORing two ciphertexts from the same keystream cancels out the keystream, + revealing the XOR of the plaintexts. + + Encrypt two different plaintexts with the same key and nonce. + XOR the resulting ciphertexts to remove the keystream, leaving only the XOR of plaintexts. + XOR the result with the first plaintext to recover the second plaintext. + Assert that the recovered plaintext matches the original second plaintext. + + Expected Behavior: + - If the ChaCha20 implementation is correct, reusing the same key and nonce + will expose the XOR relationship between plaintexts. + - The test should successfully recover the second plaintext using XOR operations. + + Assertion: + - Raises an AssertionError if the recovered plaintext does not match the + original second plaintext, indicating a failure in the XOR recovery logic. + + Output: + - Prints the original second plaintext. + - Prints the recovered plaintext (should be identical to the original). + - Displays the XOR result (hexadecimal format) for inspection. + + Security Note: + - This test highlights why it is critical never to reuse the same key and nonce + in stream ciphers like ChaCha20. + """ + + + cipher1 = ChaCha20(VALID_KEY, VALID_NONCE) + cipher2 = ChaCha20(VALID_KEY, VALID_NONCE) + + plaintext1 = b"Hello, this is message one!" + plaintext2 = b"Hi there, this is message two!" + + ciphertext1 = cipher1.encrypt(plaintext1) + ciphertext2 = cipher2.encrypt(plaintext2) + + xor_result = [] + for c1_byte, c2_byte in zip(ciphertext1, ciphertext2): + xor_result.append(c1_byte ^ c2_byte) + xor_bytes = bytes(xor_result) + recovered = [] + for xor_byte, p1_byte in zip(xor_bytes, plaintext1): + recovered.append(xor_byte ^ p1_byte) + recovered_plaintext = bytes(recovered) + assert recovered_plaintext == plaintext2, "Failed to recover second plaintext from XOR pattern" From a2a2eddf0d2784d8605864c3ac7cd774866bc937 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:49:29 -0700 Subject: [PATCH 19/35] Modify import statement of crypto/tests/test_chacha20.py --- crypto/tests/test_chacha20.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py index 0826947e9..5b64947dd 100644 --- a/crypto/tests/test_chacha20.py +++ b/crypto/tests/test_chacha20.py @@ -1,6 +1,6 @@ import random import string -from crypto.ChaCha20 import ChaCha20 +from pydatastructs.crypto.ChaCha20 import ChaCha20 VALID_KEY = b"\x00" *32 assert len(VALID_KEY) == 32, "VALID_KEY must be exactly 32 bytes" From 740f0639dbf82a6a7c68a137609770467d0bba93 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:54:00 -0700 Subject: [PATCH 20/35] Modify import statemnt of tests/test_chacha20.py --- pydatastructs/crypto/tests/test_chacha20.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydatastructs/crypto/tests/test_chacha20.py b/pydatastructs/crypto/tests/test_chacha20.py index 0826947e9..5b64947dd 100644 --- a/pydatastructs/crypto/tests/test_chacha20.py +++ b/pydatastructs/crypto/tests/test_chacha20.py @@ -1,6 +1,6 @@ import random import string -from crypto.ChaCha20 import ChaCha20 +from pydatastructs.crypto.ChaCha20 import ChaCha20 VALID_KEY = b"\x00" *32 assert len(VALID_KEY) == 32, "VALID_KEY must be exactly 32 bytes" From d8ea86581f98c0e5dafa1eade8b88304f17abcda Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:56:15 -0700 Subject: [PATCH 21/35] Delete pydatastructs/crypto/tests/test_chacha20.py --- crypto/tests/test_chacha20.py | 114 ---------------------------------- 1 file changed, 114 deletions(-) delete mode 100644 crypto/tests/test_chacha20.py diff --git a/crypto/tests/test_chacha20.py b/crypto/tests/test_chacha20.py deleted file mode 100644 index 5b64947dd..000000000 --- a/crypto/tests/test_chacha20.py +++ /dev/null @@ -1,114 +0,0 @@ -import random -import string -from pydatastructs.crypto.ChaCha20 import ChaCha20 - -VALID_KEY = b"\x00" *32 -assert len(VALID_KEY) == 32, "VALID_KEY must be exactly 32 bytes" -VALID_NONCE = B"\x00" * 12 -assert len(VALID_NONCE) == 12, "VALID_NONCE must be exactly 12 bytes" - -secure_rng = random.SystemRandom() - -def test_invalid_key_size(): - """Test invalid key sizes.""" - try: - ChaCha20(b"short_key", VALID_NONCE) - except ValueError as e: - assert "Key must be exactly 32 bytes" in str(e) - else: - assert False, "ValueError was not raised for short key" - - try: - ChaCha20(b"A" * 33, VALID_NONCE) - except ValueError as e: - assert "Key must be exactly 32 bytes" in str(e) - else: - assert False, "ValueError was not raised for long key" - -def test_invalid_nonce_size(): - """Test invalid nonce sizes.""" - try: - ChaCha20(VALID_KEY, b"short") - except ValueError as e: - assert "Nonce must be exactly 12 bytes" in str(e) - else: - assert False, "ValueError was not raised for short nonce" - - try: - ChaCha20(VALID_KEY, b"A" * 13) - except ValueError as e: - assert "Nonce must be exactly 12 bytes" in str(e) - else: - assert False, "ValueError was not raised for long nonce" - -def test_invalid_counter_values(): - """Test invalid counter values for ChaCha20.""" - for invalid_counter in [-1, -100, -999999]: - try: - ChaCha20(VALID_KEY, VALID_NONCE, counter=invalid_counter) - except ValueError as e: - assert "Counter must be a non-negative integer" in str(e) - else: - assert False, f"ValueError not raised for counter={invalid_counter}" - -def test_encrypt_decrypt(): - """Test encryption and decryption are symmetric.""" - cipher = ChaCha20(VALID_KEY, VALID_NONCE) - plaintext = b"Hello, ChaCha20!" - ciphertext = cipher.encrypt(plaintext) - decrypted = cipher.decrypt(ciphertext) - - assert decrypted == plaintext, "Decryption failed. Plaintext does not match." - -def test_key_reuse_simple(): - """ - Test the vulnerability of key reuse in ChaCha20 encryption. - - This test demonstrates the security flaw of reusing the same key and nonce - for different plaintexts in stream ciphers. It exploits the property that - XORing two ciphertexts from the same keystream cancels out the keystream, - revealing the XOR of the plaintexts. - - Encrypt two different plaintexts with the same key and nonce. - XOR the resulting ciphertexts to remove the keystream, leaving only the XOR of plaintexts. - XOR the result with the first plaintext to recover the second plaintext. - Assert that the recovered plaintext matches the original second plaintext. - - Expected Behavior: - - If the ChaCha20 implementation is correct, reusing the same key and nonce - will expose the XOR relationship between plaintexts. - - The test should successfully recover the second plaintext using XOR operations. - - Assertion: - - Raises an AssertionError if the recovered plaintext does not match the - original second plaintext, indicating a failure in the XOR recovery logic. - - Output: - - Prints the original second plaintext. - - Prints the recovered plaintext (should be identical to the original). - - Displays the XOR result (hexadecimal format) for inspection. - - Security Note: - - This test highlights why it is critical never to reuse the same key and nonce - in stream ciphers like ChaCha20. - """ - - - cipher1 = ChaCha20(VALID_KEY, VALID_NONCE) - cipher2 = ChaCha20(VALID_KEY, VALID_NONCE) - - plaintext1 = b"Hello, this is message one!" - plaintext2 = b"Hi there, this is message two!" - - ciphertext1 = cipher1.encrypt(plaintext1) - ciphertext2 = cipher2.encrypt(plaintext2) - - xor_result = [] - for c1_byte, c2_byte in zip(ciphertext1, ciphertext2): - xor_result.append(c1_byte ^ c2_byte) - xor_bytes = bytes(xor_result) - recovered = [] - for xor_byte, p1_byte in zip(xor_bytes, plaintext1): - recovered.append(xor_byte ^ p1_byte) - recovered_plaintext = bytes(recovered) - assert recovered_plaintext == plaintext2, "Failed to recover second plaintext from XOR pattern" From 16215674991b3c6f6018005e6a96da7f5666bd86 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 13:57:34 -0700 Subject: [PATCH 22/35] Add __init__.py file --- pydatastructs/crypto/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 pydatastructs/crypto/__init__.py diff --git a/pydatastructs/crypto/__init__.py b/pydatastructs/crypto/__init__.py new file mode 100644 index 000000000..e69de29bb From b041268230cc5b0ce2271ee5fa8c3e0c1e9bf6ab Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:01:01 -0700 Subject: [PATCH 23/35] Modify __init__.py --- pydatastructs/crypto/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pydatastructs/crypto/__init__.py b/pydatastructs/crypto/__init__.py index e69de29bb..8ac1efd5e 100644 --- a/pydatastructs/crypto/__init__.py +++ b/pydatastructs/crypto/__init__.py @@ -0,0 +1,2 @@ +from .ChaCha20 import ChaCha20 +__all__ = ["ChaCha20"] \ No newline at end of file From 34a1e405d024f595ef83c61f439d2c6167dc2201 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:05:09 -0700 Subject: [PATCH 24/35] Add __init__.py for crypto/tests/test_chacha20.py --- pydatastructs/crypto/tests/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 pydatastructs/crypto/tests/__init__.py diff --git a/pydatastructs/crypto/tests/__init__.py b/pydatastructs/crypto/tests/__init__.py new file mode 100644 index 000000000..e69de29bb From 2b915d06adef324f7c4d73c54656e5c0b04e2771 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:23:07 -0700 Subject: [PATCH 25/35] Modify test_chacha20.py --- pydatastructs/crypto/tests/test_chacha20.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydatastructs/crypto/tests/test_chacha20.py b/pydatastructs/crypto/tests/test_chacha20.py index 5b64947dd..76edf7082 100644 --- a/pydatastructs/crypto/tests/test_chacha20.py +++ b/pydatastructs/crypto/tests/test_chacha20.py @@ -2,7 +2,7 @@ import string from pydatastructs.crypto.ChaCha20 import ChaCha20 -VALID_KEY = b"\x00" *32 +VALID_KEY = B"\x00" *32 assert len(VALID_KEY) == 32, "VALID_KEY must be exactly 32 bytes" VALID_NONCE = B"\x00" * 12 assert len(VALID_NONCE) == 12, "VALID_NONCE must be exactly 12 bytes" From d19ad9fbe15600b02f34d4d471d0b598913cda4d Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:29:14 -0700 Subject: [PATCH 26/35] enhance ChaCha20 implementation with __new__, __init__, __repr__, reset method --- pydatastructs/crypto/ChaCha20.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index aed92f444..1b7441042 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -21,11 +21,25 @@ def __new__(cls, key: bytes, nonce: bytes, counter: int = 0): raise ValueError("Key must be exactly 32 bytes (256 bits).") if not isinstance(nonce, bytes) or len(nonce) != 12: raise ValueError("Nonce must be exactly 12 bytes (96 bits).") + if not isinstance(counter, int) or counter < 0: + raise ValueError("Counter must be a non-negative integer.") instance = super().__new__(cls) instance.key = key instance.nonce = nonce instance.counter = counter return instance + + def __init__(self, key: bytes, nonce: bytes, counter: int = 0): + """Initializes the ChaCha20 object.""" + # Guard against multiple initializations + if hasattr(self, "_initialized") and self._initialized: + return + self._initialized = True + + def __repr__(self): + """Returns a string representation of the object for debugging.""" + return f"" + def _quarter_round(self, state: List[int], a: int, b: int, c: int, d: int): state[a] = (state[a] + state[b]) % (2**32) state[d] ^= state[a] @@ -133,4 +147,10 @@ def decrypt(self, ciphertext: bytes) -> bytes: Returns: bytes: The resulting plaintext. """ - return self.apply_keystream(ciphertext) \ No newline at end of file + return self.apply_keystream(ciphertext) + + def reset(self, counter: int = 0): + """Resets the ChaCha20 counter to the specified value (default is 0).""" + if not isinstance(counter, int) or counter < 0: + raise ValueError("Counter must be a non-negative integer.") + self.counter = counter \ No newline at end of file From 8637d0c3977e3cf66bc9f270cd95607669be25d6 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:36:18 -0700 Subject: [PATCH 27/35] Remove extra trailing whitespace --- pydatastructs/crypto/tests/test_chacha20.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pydatastructs/crypto/tests/test_chacha20.py b/pydatastructs/crypto/tests/test_chacha20.py index 76edf7082..3e4fd188a 100644 --- a/pydatastructs/crypto/tests/test_chacha20.py +++ b/pydatastructs/crypto/tests/test_chacha20.py @@ -64,10 +64,10 @@ def test_key_reuse_simple(): """ Test the vulnerability of key reuse in ChaCha20 encryption. - This test demonstrates the security flaw of reusing the same key and nonce - for different plaintexts in stream ciphers. It exploits the property that - XORing two ciphertexts from the same keystream cancels out the keystream, - revealing the XOR of the plaintexts. + This test demonstrates the security flaw of reusing the same key and nonce + for different plaintexts in stream ciphers. It exploits the property that + XORing two ciphertexts from the same keystream cancels out the keystream, + revealing the XOR of the plaintexts. Encrypt two different plaintexts with the same key and nonce. XOR the resulting ciphertexts to remove the keystream, leaving only the XOR of plaintexts. @@ -75,12 +75,12 @@ def test_key_reuse_simple(): Assert that the recovered plaintext matches the original second plaintext. Expected Behavior: - - If the ChaCha20 implementation is correct, reusing the same key and nonce + - If the ChaCha20 implementation is correct, reusing the same key and nonce will expose the XOR relationship between plaintexts. - The test should successfully recover the second plaintext using XOR operations. Assertion: - - Raises an AssertionError if the recovered plaintext does not match the + - Raises an AssertionError if the recovered plaintext does not match the original second plaintext, indicating a failure in the XOR recovery logic. Output: @@ -89,11 +89,11 @@ def test_key_reuse_simple(): - Displays the XOR result (hexadecimal format) for inspection. Security Note: - - This test highlights why it is critical never to reuse the same key and nonce + - This test highlights why it is critical never to reuse the same key and nonce in stream ciphers like ChaCha20. """ - + cipher1 = ChaCha20(VALID_KEY, VALID_NONCE) cipher2 = ChaCha20(VALID_KEY, VALID_NONCE) From be3c7178854c7a9aafe63296921508e42c52a680 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:40:51 -0700 Subject: [PATCH 28/35] Add a newline at the end of files --- pydatastructs/crypto/ChaCha20.py | 22 +- pydatastructs/crypto/__init__.py | 2 +- pydatastructs/trees/heaps.py | 1164 +++++++++++------------ pydatastructs/trees/tests/test_heaps.py | 472 ++++----- 4 files changed, 830 insertions(+), 830 deletions(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index 1b7441042..0b7cade4b 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -6,7 +6,7 @@ class ChaCha20: """ Implementation of the ChaCha20 stream cipher. - + Attributes ---------- key : bytes @@ -28,7 +28,7 @@ def __new__(cls, key: bytes, nonce: bytes, counter: int = 0): instance.nonce = nonce instance.counter = counter return instance - + def __init__(self, key: bytes, nonce: bytes, counter: int = 0): """Initializes the ChaCha20 object.""" # Guard against multiple initializations @@ -81,10 +81,10 @@ def _chacha20_block(self, counter: int) -> bytes: self._double_round(working_state) final_state = (working_state + state) % (2**32) return struct.pack('<16I', *final_state.flatten()) - + def _apply_keystream(self, data: bytes) -> bytes: """ - Applies the ChaCha20 keystream to the input data (plaintext or ciphertext) + Applies the ChaCha20 keystream to the input data (plaintext or ciphertext) to perform encryption or decryption. This method processes the input data in 64-byte blocks. For each block: @@ -92,14 +92,14 @@ def _apply_keystream(self, data: bytes) -> bytes: - Each byte of the input block is XORed with the corresponding keystream byte. - The XORed result is appended to the output. - The same function is used for both encryption and decryption because + The same function is used for both encryption and decryption because XORing the ciphertext with the same keystream returns the original plaintext. Args: data (bytes): The input data to be encrypted or decrypted (plaintext or ciphertext). Returns: - bytes: The result of XORing the input data with the ChaCha20 keystream + bytes: The result of XORing the input data with the ChaCha20 keystream (ciphertext if plaintext was provided, plaintext if ciphertext was provided). """ if len(data) == 0: @@ -123,7 +123,7 @@ def encrypt(self, plaintext: bytes) -> bytes: """ Encrypts the given plaintext using the ChaCha20 stream cipher. - This method uses the ChaCha20 keystream generated from the + This method uses the ChaCha20 keystream generated from the key, nonce, and counter to XOR with the plaintext, producing ciphertext. Args: @@ -133,12 +133,12 @@ def encrypt(self, plaintext: bytes) -> bytes: bytes: The resulting ciphertext. """ return self._apply_keystream(plaintext) - + def decrypt(self, ciphertext: bytes) -> bytes: """ Decrypts the given ciphertext using the ChaCha20 stream cipher. - Since ChaCha20 uses XOR for encryption, decryption is performed + Since ChaCha20 uses XOR for encryption, decryption is performed using the same keystream and XOR operation. Args: @@ -148,9 +148,9 @@ def decrypt(self, ciphertext: bytes) -> bytes: bytes: The resulting plaintext. """ return self.apply_keystream(ciphertext) - + def reset(self, counter: int = 0): """Resets the ChaCha20 counter to the specified value (default is 0).""" if not isinstance(counter, int) or counter < 0: raise ValueError("Counter must be a non-negative integer.") - self.counter = counter \ No newline at end of file + self.counter = counter diff --git a/pydatastructs/crypto/__init__.py b/pydatastructs/crypto/__init__.py index 8ac1efd5e..ea6615cb8 100644 --- a/pydatastructs/crypto/__init__.py +++ b/pydatastructs/crypto/__init__.py @@ -1,2 +1,2 @@ from .ChaCha20 import ChaCha20 -__all__ = ["ChaCha20"] \ No newline at end of file +__all__ = ["ChaCha20"] diff --git a/pydatastructs/trees/heaps.py b/pydatastructs/trees/heaps.py index 12133a6f1..ffa2323ae 100644 --- a/pydatastructs/trees/heaps.py +++ b/pydatastructs/trees/heaps.py @@ -1,582 +1,582 @@ -from pydatastructs.utils.misc_util import ( - _check_type, TreeNode, BinomialTreeNode, - Backend, raise_if_backend_is_not_python) -from pydatastructs.linear_data_structures.arrays import ( - DynamicOneDimensionalArray, Array) -from pydatastructs.miscellaneous_data_structures.binomial_trees import BinomialTree - -__all__ = [ - 'BinaryHeap', - 'TernaryHeap', - 'DHeap', - 'BinomialHeap' -] - -class Heap(object): - """ - Abstract class for representing heaps. - """ - pass - - -class DHeap(Heap): - """ - Represents D-ary Heap. - - Parameters - ========== - - elements: list, tuple, Array - Optional, by default 'None'. - list/tuple/Array of initial TreeNode in Heap. - heap_property: str - If the key stored in each node is - either greater than or equal to - the keys in the node's children - then pass 'max'. - If the key stored in each node is - either less than or equal to - the keys in the node's children - then pass 'min'. - By default, the heap property is - set to 'min'. - backend: pydatastructs.Backend - The backend to be used. - Optional, by default, the best available - backend is used. - - Examples - ======== - - >>> from pydatastructs.trees.heaps import DHeap - >>> min_heap = DHeap(heap_property="min", d=3) - >>> min_heap.insert(1, 1) - >>> min_heap.insert(5, 5) - >>> min_heap.insert(7, 7) - >>> min_heap.extract().key - 1 - >>> min_heap.insert(4, 4) - >>> min_heap.extract().key - 4 - - >>> max_heap = DHeap(heap_property='max', d=2) - >>> max_heap.insert(1, 1) - >>> max_heap.insert(5, 5) - >>> max_heap.insert(7, 7) - >>> max_heap.extract().key - 7 - >>> max_heap.insert(6, 6) - >>> max_heap.extract().key - 6 - - References - ========== - - .. [1] https://en.wikipedia.org/wiki/D-ary_heap - """ - __slots__ = ['_comp', 'heap', 'd', 'heap_property', '_last_pos_filled'] - - def __new__(cls, elements=None, heap_property="min", d=4, - **kwargs): - raise_if_backend_is_not_python( - cls, kwargs.get('backend', Backend.PYTHON)) - obj = Heap.__new__(cls) - obj.heap_property = heap_property - obj.d = d - if heap_property == "min": - obj._comp = lambda key_parent, key_child: key_parent <= key_child - elif heap_property == "max": - obj._comp = lambda key_parent, key_child: key_parent >= key_child - else: - raise ValueError("%s is invalid heap property"%(heap_property)) - if elements is None: - elements = DynamicOneDimensionalArray(TreeNode, 0) - elif _check_type(elements, (list,tuple)): - elements = DynamicOneDimensionalArray(TreeNode, len(elements), elements) - elif _check_type(elements, Array): - elements = DynamicOneDimensionalArray(TreeNode, len(elements), elements._data) - else: - raise ValueError(f'Expected a list/tuple/Array of TreeNode got {type(elements)}') - obj.heap = elements - obj._last_pos_filled = obj.heap._last_pos_filled - obj._build() - return obj - - @classmethod - def methods(cls): - return ['__new__', 'insert', 'extract', '__str__', 'is_empty'] - - def _build(self): - for i in range(self._last_pos_filled + 1): - self.heap[i]._leftmost, self.heap[i]._rightmost = \ - self.d*i + 1, self.d*i + self.d - for i in range((self._last_pos_filled + 1)//self.d, -1, -1): - self._heapify(i) - - def _swap(self, idx1, idx2): - idx1_key, idx1_data = \ - self.heap[idx1].key, self.heap[idx1].data - self.heap[idx1].key, self.heap[idx1].data = \ - self.heap[idx2].key, self.heap[idx2].data - self.heap[idx2].key, self.heap[idx2].data = \ - idx1_key, idx1_data - - def _heapify(self, i): - while True: - target = i - l = self.d*i + 1 - r = self.d*i + self.d - - for j in range(l, r+1): - if j <= self._last_pos_filled: - target = j if self._comp(self.heap[j].key, self.heap[target].key) \ - else target - else: - break - - if target != i: - self._swap(target, i) - i = target - else: - break - - def insert(self, key, data=None): - """ - Insert a new element to the heap according to heap property. - - Parameters - ========== - - key - The key for comparison. - data - The data to be inserted. - - Returns - ======= - - None - """ - new_node = TreeNode(key, data) - self.heap.append(new_node) - self._last_pos_filled += 1 - i = self._last_pos_filled - self.heap[i]._leftmost, self.heap[i]._rightmost = self.d*i + 1, self.d*i + self.d - - while True: - parent = (i - 1)//self.d - if i == 0 or self._comp(self.heap[parent].key, self.heap[i].key): - break - else: - self._swap(i, parent) - i = parent - - def extract(self): - """ - Extract root element of the Heap. - - Returns - ======= - - root_element: TreeNode - The TreeNode at the root of the heap, - if the heap is not empty. - - None - If the heap is empty. - """ - if self._last_pos_filled == -1: - raise IndexError("Heap is empty.") - else: - element_to_be_extracted = TreeNode(self.heap[0].key, self.heap[0].data) - self._swap(0, self._last_pos_filled) - self.heap.delete(self._last_pos_filled) - self._last_pos_filled -= 1 - self._heapify(0) - return element_to_be_extracted - - def __str__(self): - to_be_printed = ['' for i in range(self._last_pos_filled + 1)] - for i in range(self._last_pos_filled + 1): - node = self.heap[i] - if node._leftmost <= self._last_pos_filled: - if node._rightmost <= self._last_pos_filled: - children = list(range(node._leftmost, node._rightmost + 1)) - else: - children = list(range(node._leftmost, self._last_pos_filled + 1)) - else: - children = [] - to_be_printed[i] = (node.key, node.data, children) - return str(to_be_printed) - - @property - def is_empty(self): - """ - Checks if the heap is empty. - """ - return self.heap._last_pos_filled == -1 - - -class BinaryHeap(DHeap): - """ - Represents Binary Heap. - - Parameters - ========== - - elements: list, tuple - Optional, by default 'None'. - List/tuple of initial elements in Heap. - heap_property: str - If the key stored in each node is - either greater than or equal to - the keys in the node's children - then pass 'max'. - If the key stored in each node is - either less than or equal to - the keys in the node's children - then pass 'min'. - By default, the heap property is - set to 'min'. - backend: pydatastructs.Backend - The backend to be used. - Optional, by default, the best available - backend is used. - - Examples - ======== - - >>> from pydatastructs.trees.heaps import BinaryHeap - >>> min_heap = BinaryHeap(heap_property="min") - >>> min_heap.insert(1, 1) - >>> min_heap.insert(5, 5) - >>> min_heap.insert(7, 7) - >>> min_heap.extract().key - 1 - >>> min_heap.insert(4, 4) - >>> min_heap.extract().key - 4 - - >>> max_heap = BinaryHeap(heap_property='max') - >>> max_heap.insert(1, 1) - >>> max_heap.insert(5, 5) - >>> max_heap.insert(7, 7) - >>> max_heap.extract().key - 7 - >>> max_heap.insert(6, 6) - >>> max_heap.extract().key - 6 - - References - ========== - - .. [1] https://en.m.wikipedia.org/wiki/Binary_heap - """ - def __new__(cls, elements=None, heap_property="min", - **kwargs): - raise_if_backend_is_not_python( - cls, kwargs.get('backend', Backend.PYTHON)) - obj = DHeap.__new__(cls, elements, heap_property, 2) - return obj - - @classmethod - def methods(cls): - return ['__new__'] - - -class TernaryHeap(DHeap): - """ - Represents Ternary Heap. - - Parameters - ========== - - elements: list, tuple - Optional, by default 'None'. - List/tuple of initial elements in Heap. - heap_property: str - If the key stored in each node is - either greater than or equal to - the keys in the node's children - then pass 'max'. - If the key stored in each node is - either less than or equal to - the keys in the node's children - then pass 'min'. - By default, the heap property is - set to 'min'. - backend: pydatastructs.Backend - The backend to be used. - Optional, by default, the best available - backend is used. - - Examples - ======== - - >>> from pydatastructs.trees.heaps import TernaryHeap - >>> min_heap = TernaryHeap(heap_property="min") - >>> min_heap.insert(1, 1) - >>> min_heap.insert(5, 5) - >>> min_heap.insert(7, 7) - >>> min_heap.insert(3, 3) - >>> min_heap.extract().key - 1 - >>> min_heap.insert(4, 4) - >>> min_heap.extract().key - 3 - - >>> max_heap = TernaryHeap(heap_property='max') - >>> max_heap.insert(1, 1) - >>> max_heap.insert(5, 5) - >>> max_heap.insert(7, 7) - >>> min_heap.insert(3, 3) - >>> max_heap.extract().key - 7 - >>> max_heap.insert(6, 6) - >>> max_heap.extract().key - 6 - - References - ========== - - .. [1] https://en.wikipedia.org/wiki/D-ary_heap - .. [2] https://ece.uwaterloo.ca/~dwharder/aads/Algorithms/d-ary_heaps/Ternary_heaps/ - """ - def __new__(cls, elements=None, heap_property="min", - **kwargs): - raise_if_backend_is_not_python( - cls, kwargs.get('backend', Backend.PYTHON)) - obj = DHeap.__new__(cls, elements, heap_property, 3) - return obj - - @classmethod - def methods(cls): - return ['__new__'] - - -class BinomialHeap(Heap): - """ - Represents binomial heap. - - Parameters - ========== - - root_list: list/tuple/Array - By default, [] - The list of BinomialTree object references - in sorted order. - backend: pydatastructs.Backend - The backend to be used. - Optional, by default, the best available - backend is used. - - Examples - ======== - - >>> from pydatastructs import BinomialHeap - >>> b = BinomialHeap() - >>> b.insert(1, 1) - >>> b.insert(2, 2) - >>> b.find_minimum().key - 1 - >>> b.find_minimum().children[0].key - 2 - - References - ========== - - .. [1] https://en.wikipedia.org/wiki/Binomial_heap - """ - __slots__ = ['root_list'] - - def __new__(cls, root_list=None, **kwargs): - raise_if_backend_is_not_python( - cls, kwargs.get('backend', Backend.PYTHON)) - if root_list is None: - root_list = [] - if not all((_check_type(root, BinomialTree)) - for root in root_list): - raise TypeError("The root_list should contain " - "references to objects of BinomialTree.") - obj = Heap.__new__(cls) - obj.root_list = root_list - return obj - - @classmethod - def methods(cls): - return ['__new__', 'merge_tree', 'merge', 'insert', - 'find_minimum', 'is_emtpy', 'decrease_key', 'delete', - 'delete_minimum'] - - def merge_tree(self, tree1, tree2): - """ - Merges two BinomialTree objects. - - Parameters - ========== - - tree1: BinomialTree - - tree2: BinomialTree - """ - if (not _check_type(tree1, BinomialTree)) or \ - (not _check_type(tree2, BinomialTree)): - raise TypeError("Both the trees should be of type " - "BinomalTree.") - ret_value = None - if tree1.root.key <= tree2.root.key: - tree1.add_sub_tree(tree2) - ret_value = tree1 - else: - tree2.add_sub_tree(tree1) - ret_value = tree2 - return ret_value - - def _merge_heap_last_new_tree(self, new_root_list, new_tree): - """ - Merges last tree node in root list with the incoming tree. - """ - pos = -1 - if len(new_root_list) > 0 and new_root_list[pos].order == new_tree.order: - new_root_list[pos] = self.merge_tree(new_root_list[pos], new_tree) - else: - new_root_list.append(new_tree) - - def merge(self, other_heap): - """ - Merges current binomial heap with the given binomial heap. - - Parameters - ========== - - other_heap: BinomialHeap - """ - if not _check_type(other_heap, BinomialHeap): - raise TypeError("Other heap is not of type BinomialHeap.") - new_root_list = [] - i, j = 0, 0 - while (i < len(self.root_list)) and \ - (j < len(other_heap.root_list)): - new_tree = None - while self.root_list[i] is None: - i += 1 - while other_heap.root_list[j] is None: - j += 1 - if self.root_list[i].order == other_heap.root_list[j].order: - new_tree = self.merge_tree(self.root_list[i], - other_heap.root_list[j]) - i += 1 - j += 1 - else: - if self.root_list[i].order < other_heap.root_list[j].order: - new_tree = self.root_list[i] - i += 1 - else: - new_tree = other_heap.root_list[j] - j += 1 - self._merge_heap_last_new_tree(new_root_list, new_tree) - - while i < len(self.root_list): - new_tree = self.root_list[i] - self._merge_heap_last_new_tree(new_root_list, new_tree) - i += 1 - while j < len(other_heap.root_list): - new_tree = other_heap.root_list[j] - self._merge_heap_last_new_tree(new_root_list, new_tree) - j += 1 - self.root_list = new_root_list - - def insert(self, key, data=None): - """ - Inserts new node with the given key and data. - - key - The key of the node which can be operated - upon by relational operators. - - data - The data to be stored in the new node. - """ - new_node = BinomialTreeNode(key, data) - new_tree = BinomialTree(root=new_node, order=0) - new_heap = BinomialHeap(root_list=[new_tree]) - self.merge(new_heap) - - def find_minimum(self, **kwargs): - """ - Finds the node with the minimum key. - - Returns - ======= - - min_node: BinomialTreeNode - """ - if self.is_empty: - raise IndexError("Binomial heap is empty.") - min_node = None - idx, min_idx = 0, None - for tree in self.root_list: - if ((min_node is None) or - (tree is not None and tree.root is not None and - min_node.key > tree.root.key)): - min_node = tree.root - min_idx = idx - idx += 1 - if kwargs.get('get_index', None) is not None: - return min_node, min_idx - return min_node - - def delete_minimum(self): - """ - Deletes the node with minimum key. - """ - min_node, min_idx = self.find_minimum(get_index=True) - child_root_list = [] - for k, child in enumerate(min_node.children): - if child is not None: - child_root_list.append(BinomialTree(root=child, order=k)) - self.root_list.remove(self.root_list[min_idx]) - child_heap = BinomialHeap(root_list=child_root_list) - self.merge(child_heap) - - @property - def is_empty(self): - return not self.root_list - - def decrease_key(self, node, new_key): - """ - Decreases the key of the given node. - - Parameters - ========== - - node: BinomialTreeNode - The node whose key is to be reduced. - new_key - The new key of the given node, - should be less than the current key. - """ - if node.key <= new_key: - raise ValueError("The new key " - "should be less than current node's key.") - node.key = new_key - while ((not node.is_root) and - (node.parent.key > node.key)): - node.parent.key, node.key = \ - node.key, node.parent.key - node.parent.data, node.data = \ - node.data, node.parent.data - node = node.parent - - def delete(self, node): - """ - Deletes the given node. - - Parameters - ========== - - node: BinomialTreeNode - The node which is to be deleted. - """ - self.decrease_key(node, self.find_minimum().key - 1) - self.delete_minimum() +from pydatastructs.utils.misc_util import ( + _check_type, TreeNode, BinomialTreeNode, + Backend, raise_if_backend_is_not_python) +from pydatastructs.linear_data_structures.arrays import ( + DynamicOneDimensionalArray, Array) +from pydatastructs.miscellaneous_data_structures.binomial_trees import BinomialTree + +__all__ = [ + 'BinaryHeap', + 'TernaryHeap', + 'DHeap', + 'BinomialHeap' +] + +class Heap(object): + """ + Abstract class for representing heaps. + """ + pass + + +class DHeap(Heap): + """ + Represents D-ary Heap. + + Parameters + ========== + + elements: list, tuple, Array + Optional, by default 'None'. + list/tuple/Array of initial TreeNode in Heap. + heap_property: str + If the key stored in each node is + either greater than or equal to + the keys in the node's children + then pass 'max'. + If the key stored in each node is + either less than or equal to + the keys in the node's children + then pass 'min'. + By default, the heap property is + set to 'min'. + backend: pydatastructs.Backend + The backend to be used. + Optional, by default, the best available + backend is used. + + Examples + ======== + + >>> from pydatastructs.trees.heaps import DHeap + >>> min_heap = DHeap(heap_property="min", d=3) + >>> min_heap.insert(1, 1) + >>> min_heap.insert(5, 5) + >>> min_heap.insert(7, 7) + >>> min_heap.extract().key + 1 + >>> min_heap.insert(4, 4) + >>> min_heap.extract().key + 4 + + >>> max_heap = DHeap(heap_property='max', d=2) + >>> max_heap.insert(1, 1) + >>> max_heap.insert(5, 5) + >>> max_heap.insert(7, 7) + >>> max_heap.extract().key + 7 + >>> max_heap.insert(6, 6) + >>> max_heap.extract().key + 6 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/D-ary_heap + """ + __slots__ = ['_comp', 'heap', 'd', 'heap_property', '_last_pos_filled'] + + def __new__(cls, elements=None, heap_property="min", d=4, + **kwargs): + raise_if_backend_is_not_python( + cls, kwargs.get('backend', Backend.PYTHON)) + obj = Heap.__new__(cls) + obj.heap_property = heap_property + obj.d = d + if heap_property == "min": + obj._comp = lambda key_parent, key_child: key_parent <= key_child + elif heap_property == "max": + obj._comp = lambda key_parent, key_child: key_parent >= key_child + else: + raise ValueError("%s is invalid heap property"%(heap_property)) + if elements is None: + elements = DynamicOneDimensionalArray(TreeNode, 0) + elif _check_type(elements, (list,tuple)): + elements = DynamicOneDimensionalArray(TreeNode, len(elements), elements) + elif _check_type(elements, Array): + elements = DynamicOneDimensionalArray(TreeNode, len(elements), elements._data) + else: + raise ValueError(f'Expected a list/tuple/Array of TreeNode got {type(elements)}') + obj.heap = elements + obj._last_pos_filled = obj.heap._last_pos_filled + obj._build() + return obj + + @classmethod + def methods(cls): + return ['__new__', 'insert', 'extract', '__str__', 'is_empty'] + + def _build(self): + for i in range(self._last_pos_filled + 1): + self.heap[i]._leftmost, self.heap[i]._rightmost = \ + self.d*i + 1, self.d*i + self.d + for i in range((self._last_pos_filled + 1)//self.d, -1, -1): + self._heapify(i) + + def _swap(self, idx1, idx2): + idx1_key, idx1_data = \ + self.heap[idx1].key, self.heap[idx1].data + self.heap[idx1].key, self.heap[idx1].data = \ + self.heap[idx2].key, self.heap[idx2].data + self.heap[idx2].key, self.heap[idx2].data = \ + idx1_key, idx1_data + + def _heapify(self, i): + while True: + target = i + l = self.d*i + 1 + r = self.d*i + self.d + + for j in range(l, r+1): + if j <= self._last_pos_filled: + target = j if self._comp(self.heap[j].key, self.heap[target].key) \ + else target + else: + break + + if target != i: + self._swap(target, i) + i = target + else: + break + + def insert(self, key, data=None): + """ + Insert a new element to the heap according to heap property. + + Parameters + ========== + + key + The key for comparison. + data + The data to be inserted. + + Returns + ======= + + None + """ + new_node = TreeNode(key, data) + self.heap.append(new_node) + self._last_pos_filled += 1 + i = self._last_pos_filled + self.heap[i]._leftmost, self.heap[i]._rightmost = self.d*i + 1, self.d*i + self.d + + while True: + parent = (i - 1)//self.d + if i == 0 or self._comp(self.heap[parent].key, self.heap[i].key): + break + else: + self._swap(i, parent) + i = parent + + def extract(self): + """ + Extract root element of the Heap. + + Returns + ======= + + root_element: TreeNode + The TreeNode at the root of the heap, + if the heap is not empty. + + None + If the heap is empty. + """ + if self._last_pos_filled == -1: + raise IndexError("Heap is empty.") + else: + element_to_be_extracted = TreeNode(self.heap[0].key, self.heap[0].data) + self._swap(0, self._last_pos_filled) + self.heap.delete(self._last_pos_filled) + self._last_pos_filled -= 1 + self._heapify(0) + return element_to_be_extracted + + def __str__(self): + to_be_printed = ['' for i in range(self._last_pos_filled + 1)] + for i in range(self._last_pos_filled + 1): + node = self.heap[i] + if node._leftmost <= self._last_pos_filled: + if node._rightmost <= self._last_pos_filled: + children = list(range(node._leftmost, node._rightmost + 1)) + else: + children = list(range(node._leftmost, self._last_pos_filled + 1)) + else: + children = [] + to_be_printed[i] = (node.key, node.data, children) + return str(to_be_printed) + + @property + def is_empty(self): + """ + Checks if the heap is empty. + """ + return self.heap._last_pos_filled == -1 + + +class BinaryHeap(DHeap): + """ + Represents Binary Heap. + + Parameters + ========== + + elements: list, tuple + Optional, by default 'None'. + List/tuple of initial elements in Heap. + heap_property: str + If the key stored in each node is + either greater than or equal to + the keys in the node's children + then pass 'max'. + If the key stored in each node is + either less than or equal to + the keys in the node's children + then pass 'min'. + By default, the heap property is + set to 'min'. + backend: pydatastructs.Backend + The backend to be used. + Optional, by default, the best available + backend is used. + + Examples + ======== + + >>> from pydatastructs.trees.heaps import BinaryHeap + >>> min_heap = BinaryHeap(heap_property="min") + >>> min_heap.insert(1, 1) + >>> min_heap.insert(5, 5) + >>> min_heap.insert(7, 7) + >>> min_heap.extract().key + 1 + >>> min_heap.insert(4, 4) + >>> min_heap.extract().key + 4 + + >>> max_heap = BinaryHeap(heap_property='max') + >>> max_heap.insert(1, 1) + >>> max_heap.insert(5, 5) + >>> max_heap.insert(7, 7) + >>> max_heap.extract().key + 7 + >>> max_heap.insert(6, 6) + >>> max_heap.extract().key + 6 + + References + ========== + + .. [1] https://en.m.wikipedia.org/wiki/Binary_heap + """ + def __new__(cls, elements=None, heap_property="min", + **kwargs): + raise_if_backend_is_not_python( + cls, kwargs.get('backend', Backend.PYTHON)) + obj = DHeap.__new__(cls, elements, heap_property, 2) + return obj + + @classmethod + def methods(cls): + return ['__new__'] + + +class TernaryHeap(DHeap): + """ + Represents Ternary Heap. + + Parameters + ========== + + elements: list, tuple + Optional, by default 'None'. + List/tuple of initial elements in Heap. + heap_property: str + If the key stored in each node is + either greater than or equal to + the keys in the node's children + then pass 'max'. + If the key stored in each node is + either less than or equal to + the keys in the node's children + then pass 'min'. + By default, the heap property is + set to 'min'. + backend: pydatastructs.Backend + The backend to be used. + Optional, by default, the best available + backend is used. + + Examples + ======== + + >>> from pydatastructs.trees.heaps import TernaryHeap + >>> min_heap = TernaryHeap(heap_property="min") + >>> min_heap.insert(1, 1) + >>> min_heap.insert(5, 5) + >>> min_heap.insert(7, 7) + >>> min_heap.insert(3, 3) + >>> min_heap.extract().key + 1 + >>> min_heap.insert(4, 4) + >>> min_heap.extract().key + 3 + + >>> max_heap = TernaryHeap(heap_property='max') + >>> max_heap.insert(1, 1) + >>> max_heap.insert(5, 5) + >>> max_heap.insert(7, 7) + >>> min_heap.insert(3, 3) + >>> max_heap.extract().key + 7 + >>> max_heap.insert(6, 6) + >>> max_heap.extract().key + 6 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/D-ary_heap + .. [2] https://ece.uwaterloo.ca/~dwharder/aads/Algorithms/d-ary_heaps/Ternary_heaps/ + """ + def __new__(cls, elements=None, heap_property="min", + **kwargs): + raise_if_backend_is_not_python( + cls, kwargs.get('backend', Backend.PYTHON)) + obj = DHeap.__new__(cls, elements, heap_property, 3) + return obj + + @classmethod + def methods(cls): + return ['__new__'] + + +class BinomialHeap(Heap): + """ + Represents binomial heap. + + Parameters + ========== + + root_list: list/tuple/Array + By default, [] + The list of BinomialTree object references + in sorted order. + backend: pydatastructs.Backend + The backend to be used. + Optional, by default, the best available + backend is used. + + Examples + ======== + + >>> from pydatastructs import BinomialHeap + >>> b = BinomialHeap() + >>> b.insert(1, 1) + >>> b.insert(2, 2) + >>> b.find_minimum().key + 1 + >>> b.find_minimum().children[0].key + 2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Binomial_heap + """ + __slots__ = ['root_list'] + + def __new__(cls, root_list=None, **kwargs): + raise_if_backend_is_not_python( + cls, kwargs.get('backend', Backend.PYTHON)) + if root_list is None: + root_list = [] + if not all((_check_type(root, BinomialTree)) + for root in root_list): + raise TypeError("The root_list should contain " + "references to objects of BinomialTree.") + obj = Heap.__new__(cls) + obj.root_list = root_list + return obj + + @classmethod + def methods(cls): + return ['__new__', 'merge_tree', 'merge', 'insert', + 'find_minimum', 'is_emtpy', 'decrease_key', 'delete', + 'delete_minimum'] + + def merge_tree(self, tree1, tree2): + """ + Merges two BinomialTree objects. + + Parameters + ========== + + tree1: BinomialTree + + tree2: BinomialTree + """ + if (not _check_type(tree1, BinomialTree)) or \ + (not _check_type(tree2, BinomialTree)): + raise TypeError("Both the trees should be of type " + "BinomalTree.") + ret_value = None + if tree1.root.key <= tree2.root.key: + tree1.add_sub_tree(tree2) + ret_value = tree1 + else: + tree2.add_sub_tree(tree1) + ret_value = tree2 + return ret_value + + def _merge_heap_last_new_tree(self, new_root_list, new_tree): + """ + Merges last tree node in root list with the incoming tree. + """ + pos = -1 + if len(new_root_list) > 0 and new_root_list[pos].order == new_tree.order: + new_root_list[pos] = self.merge_tree(new_root_list[pos], new_tree) + else: + new_root_list.append(new_tree) + + def merge(self, other_heap): + """ + Merges current binomial heap with the given binomial heap. + + Parameters + ========== + + other_heap: BinomialHeap + """ + if not _check_type(other_heap, BinomialHeap): + raise TypeError("Other heap is not of type BinomialHeap.") + new_root_list = [] + i, j = 0, 0 + while (i < len(self.root_list)) and \ + (j < len(other_heap.root_list)): + new_tree = None + while self.root_list[i] is None: + i += 1 + while other_heap.root_list[j] is None: + j += 1 + if self.root_list[i].order == other_heap.root_list[j].order: + new_tree = self.merge_tree(self.root_list[i], + other_heap.root_list[j]) + i += 1 + j += 1 + else: + if self.root_list[i].order < other_heap.root_list[j].order: + new_tree = self.root_list[i] + i += 1 + else: + new_tree = other_heap.root_list[j] + j += 1 + self._merge_heap_last_new_tree(new_root_list, new_tree) + + while i < len(self.root_list): + new_tree = self.root_list[i] + self._merge_heap_last_new_tree(new_root_list, new_tree) + i += 1 + while j < len(other_heap.root_list): + new_tree = other_heap.root_list[j] + self._merge_heap_last_new_tree(new_root_list, new_tree) + j += 1 + self.root_list = new_root_list + + def insert(self, key, data=None): + """ + Inserts new node with the given key and data. + + key + The key of the node which can be operated + upon by relational operators. + + data + The data to be stored in the new node. + """ + new_node = BinomialTreeNode(key, data) + new_tree = BinomialTree(root=new_node, order=0) + new_heap = BinomialHeap(root_list=[new_tree]) + self.merge(new_heap) + + def find_minimum(self, **kwargs): + """ + Finds the node with the minimum key. + + Returns + ======= + + min_node: BinomialTreeNode + """ + if self.is_empty: + raise IndexError("Binomial heap is empty.") + min_node = None + idx, min_idx = 0, None + for tree in self.root_list: + if ((min_node is None) or + (tree is not None and tree.root is not None and + min_node.key > tree.root.key)): + min_node = tree.root + min_idx = idx + idx += 1 + if kwargs.get('get_index', None) is not None: + return min_node, min_idx + return min_node + + def delete_minimum(self): + """ + Deletes the node with minimum key. + """ + min_node, min_idx = self.find_minimum(get_index=True) + child_root_list = [] + for k, child in enumerate(min_node.children): + if child is not None: + child_root_list.append(BinomialTree(root=child, order=k)) + self.root_list.remove(self.root_list[min_idx]) + child_heap = BinomialHeap(root_list=child_root_list) + self.merge(child_heap) + + @property + def is_empty(self): + return not self.root_list + + def decrease_key(self, node, new_key): + """ + Decreases the key of the given node. + + Parameters + ========== + + node: BinomialTreeNode + The node whose key is to be reduced. + new_key + The new key of the given node, + should be less than the current key. + """ + if node.key <= new_key: + raise ValueError("The new key " + "should be less than current node's key.") + node.key = new_key + while ((not node.is_root) and + (node.parent.key > node.key)): + node.parent.key, node.key = \ + node.key, node.parent.key + node.parent.data, node.data = \ + node.data, node.parent.data + node = node.parent + + def delete(self, node): + """ + Deletes the given node. + + Parameters + ========== + + node: BinomialTreeNode + The node which is to be deleted. + """ + self.decrease_key(node, self.find_minimum().key - 1) + self.delete_minimum() diff --git a/pydatastructs/trees/tests/test_heaps.py b/pydatastructs/trees/tests/test_heaps.py index dece2f132..58529b19e 100644 --- a/pydatastructs/trees/tests/test_heaps.py +++ b/pydatastructs/trees/tests/test_heaps.py @@ -1,236 +1,236 @@ -from pydatastructs.trees.heaps import BinaryHeap, TernaryHeap, BinomialHeap, DHeap -from pydatastructs.linear_data_structures.arrays import DynamicOneDimensionalArray -from pydatastructs.miscellaneous_data_structures.binomial_trees import BinomialTree -from pydatastructs.utils.misc_util import TreeNode, BinomialTreeNode -from pydatastructs.utils.raises_util import raises -from collections import deque as Queue - -def test_BinaryHeap(): - - max_heap = BinaryHeap(heap_property="max") - - assert raises(IndexError, lambda: max_heap.extract()) - - max_heap.insert(100, 100) - max_heap.insert(19, 19) - max_heap.insert(36, 36) - max_heap.insert(17, 17) - max_heap.insert(3, 3) - max_heap.insert(25, 25) - max_heap.insert(1, 1) - max_heap.insert(2, 2) - max_heap.insert(7, 7) - assert str(max_heap) == \ - ("[(100, 100, [1, 2]), (19, 19, [3, 4]), " - "(36, 36, [5, 6]), (17, 17, [7, 8]), " - "(3, 3, []), (25, 25, []), (1, 1, []), " - "(2, 2, []), (7, 7, [])]") - - assert max_heap.extract().key == 100 - - expected_sorted_elements = [36, 25, 19, 17, 7, 3, 2, 1] - l = max_heap.heap[0].left - l = max_heap.heap[0].right - sorted_elements = [] - for _ in range(8): - sorted_elements.append(max_heap.extract().key) - assert expected_sorted_elements == sorted_elements - - elements = [ - TreeNode(7, 7), TreeNode(25, 25), TreeNode(100, 100), - TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), - TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) - ] - min_heap = BinaryHeap(elements=elements, heap_property="min") - assert min_heap.extract().key == 1 - - expected_sorted_elements = [2, 3, 7, 17, 19, 25, 36, 100] - sorted_elements = [min_heap.extract().key for _ in range(8)] - assert expected_sorted_elements == sorted_elements - - non_TreeNode_elements = [ - (7, 7), TreeNode(25, 25), TreeNode(100, 100), - TreeNode(1, 1), (2, 2), TreeNode(3, 3), - TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) - ] - assert raises(TypeError, lambda: - BinaryHeap(elements = non_TreeNode_elements, heap_property='min')) - - non_TreeNode_elements = DynamicOneDimensionalArray(int, 0) - non_TreeNode_elements.append(1) - non_TreeNode_elements.append(2) - assert raises(TypeError, lambda: - BinaryHeap(elements = non_TreeNode_elements, heap_property='min')) - - non_heapable = "[1, 2, 3]" - assert raises(ValueError, lambda: - BinaryHeap(elements = non_heapable, heap_property='min')) - -def test_TernaryHeap(): - max_heap = TernaryHeap(heap_property="max") - assert raises(IndexError, lambda: max_heap.extract()) - max_heap.insert(100, 100) - max_heap.insert(19, 19) - max_heap.insert(36, 36) - max_heap.insert(17, 17) - max_heap.insert(3, 3) - max_heap.insert(25, 25) - max_heap.insert(1, 1) - max_heap.insert(2, 2) - max_heap.insert(7, 7) - assert str(max_heap) == \ - ('[(100, 100, [1, 2, 3]), (25, 25, [4, 5, 6]), ' - '(36, 36, [7, 8]), (17, 17, []), ' - '(3, 3, []), (19, 19, []), (1, 1, []), ' - '(2, 2, []), (7, 7, [])]') - - assert max_heap.extract().key == 100 - - expected_sorted_elements = [36, 25, 19, 17, 7, 3, 2, 1] - sorted_elements = [] - for _ in range(8): - sorted_elements.append(max_heap.extract().key) - assert expected_sorted_elements == sorted_elements - - elements = [ - TreeNode(7, 7), TreeNode(25, 25), TreeNode(100, 100), - TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), - TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) - ] - min_heap = TernaryHeap(elements=elements, heap_property="min") - expected_extracted_element = min_heap.heap[0].key - assert min_heap.extract().key == expected_extracted_element - - expected_sorted_elements = [2, 3, 7, 17, 19, 25, 36, 100] - sorted_elements = [min_heap.extract().key for _ in range(8)] - assert expected_sorted_elements == sorted_elements - -def test_DHeap(): - assert raises(ValueError, lambda: DHeap(heap_property="none", d=4)) - max_heap = DHeap(heap_property="max", d=5) - assert raises(IndexError, lambda: max_heap.extract()) - max_heap.insert(100, 100) - max_heap.insert(19, 19) - max_heap.insert(36, 36) - max_heap.insert(17, 17) - max_heap.insert(3, 3) - max_heap.insert(25, 25) - max_heap.insert(1, 1) - max_heap = DHeap(max_heap.heap, heap_property="max", d=4) - max_heap.insert(2, 2) - max_heap.insert(7, 7) - assert str(max_heap) == \ - ('[(100, 100, [1, 2, 3, 4]), (25, 25, [5, 6, 7, 8]), ' - '(36, 36, []), (17, 17, []), (3, 3, []), (19, 19, []), ' - '(1, 1, []), (2, 2, []), (7, 7, [])]') - - assert max_heap.extract().key == 100 - - expected_sorted_elements = [36, 25, 19, 17, 7, 3, 2, 1] - sorted_elements = [] - for _ in range(8): - sorted_elements.append(max_heap.extract().key) - assert expected_sorted_elements == sorted_elements - - elements = [ - TreeNode(7, 7), TreeNode(25, 25), TreeNode(100, 100), - TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), - TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) - ] - min_heap = DHeap(elements=DynamicOneDimensionalArray(TreeNode, 9, elements), heap_property="min") - assert min_heap.extract().key == 1 - - expected_sorted_elements = [2, 3, 7, 17, 19, 25, 36, 100] - sorted_elements = [min_heap.extract().key for _ in range(8)] - assert expected_sorted_elements == sorted_elements - -def test_BinomialHeap(): - - # Corner cases - assert raises(TypeError, lambda: - BinomialHeap( - root_list=[BinomialTreeNode(1, 1), None]) - ) is True - tree1 = BinomialTree(BinomialTreeNode(1, 1), 0) - tree2 = BinomialTree(BinomialTreeNode(2, 2), 0) - bh = BinomialHeap(root_list=[tree1, tree2]) - assert raises(TypeError, lambda: - bh.merge_tree(BinomialTreeNode(2, 2), None)) - assert raises(TypeError, lambda: - bh.merge(None)) - - # Testing BinomialHeap.merge - nodes = [BinomialTreeNode(1, 1), # 0 - BinomialTreeNode(3, 3), # 1 - BinomialTreeNode(9, 9), # 2 - BinomialTreeNode(11, 11), # 3 - BinomialTreeNode(6, 6), # 4 - BinomialTreeNode(14, 14), # 5 - BinomialTreeNode(2, 2), # 6 - BinomialTreeNode(7, 7), # 7 - BinomialTreeNode(4, 4), # 8 - BinomialTreeNode(8, 8), # 9 - BinomialTreeNode(12, 12), # 10 - BinomialTreeNode(10, 10), # 11 - BinomialTreeNode(5, 5), # 12 - BinomialTreeNode(21, 21)] # 13 - - nodes[2].add_children(nodes[3]) - nodes[4].add_children(nodes[5]) - nodes[6].add_children(nodes[9], nodes[8], nodes[7]) - nodes[7].add_children(nodes[11], nodes[10]) - nodes[8].add_children(nodes[12]) - nodes[10].add_children(nodes[13]) - - tree11 = BinomialTree(nodes[0], 0) - tree12 = BinomialTree(nodes[2], 1) - tree13 = BinomialTree(nodes[6], 3) - tree21 = BinomialTree(nodes[1], 0) - - heap1 = BinomialHeap(root_list=[tree11, tree12, tree13]) - heap2 = BinomialHeap(root_list=[tree21]) - - def bfs(heap): - bfs_trav = [] - for i in range(len(heap.root_list)): - layer = [] - bfs_q = Queue() - bfs_q.append(heap.root_list[i].root) - while len(bfs_q) != 0: - curr_node = bfs_q.popleft() - if curr_node is not None: - layer.append(curr_node.key) - for _i in range(curr_node.children._last_pos_filled + 1): - bfs_q.append(curr_node.children[_i]) - if layer != []: - bfs_trav.append(layer) - return bfs_trav - - heap1.merge(heap2) - expected_bfs_trav = [[1, 3, 9, 11], [2, 8, 4, 7, 5, 10, 12, 21]] - assert bfs(heap1) == expected_bfs_trav - - # Testing Binomial.find_minimum - assert heap1.find_minimum().key == 1 - - # Testing Binomial.delete_minimum - heap1.delete_minimum() - assert bfs(heap1) == [[3], [9, 11], [2, 8, 4, 7, 5, 10, 12, 21]] - assert raises(ValueError, lambda: heap1.decrease_key(nodes[3], 15)) - heap1.decrease_key(nodes[3], 0) - assert bfs(heap1) == [[3], [0, 9], [2, 8, 4, 7, 5, 10, 12, 21]] - heap1.delete(nodes[12]) - assert bfs(heap1) == [[3, 8], [0, 9, 2, 7, 4, 10, 12, 21]] - - # Testing BinomialHeap.insert - heap = BinomialHeap() - assert raises(IndexError, lambda: heap.find_minimum()) - heap.insert(1, 1) - heap.insert(3, 3) - heap.insert(6, 6) - heap.insert(9, 9) - heap.insert(14, 14) - heap.insert(11, 11) - heap.insert(2, 2) - heap.insert(7, 7) - assert bfs(heap) == [[1, 3, 6, 2, 9, 7, 11, 14]] +from pydatastructs.trees.heaps import BinaryHeap, TernaryHeap, BinomialHeap, DHeap +from pydatastructs.linear_data_structures.arrays import DynamicOneDimensionalArray +from pydatastructs.miscellaneous_data_structures.binomial_trees import BinomialTree +from pydatastructs.utils.misc_util import TreeNode, BinomialTreeNode +from pydatastructs.utils.raises_util import raises +from collections import deque as Queue + +def test_BinaryHeap(): + + max_heap = BinaryHeap(heap_property="max") + + assert raises(IndexError, lambda: max_heap.extract()) + + max_heap.insert(100, 100) + max_heap.insert(19, 19) + max_heap.insert(36, 36) + max_heap.insert(17, 17) + max_heap.insert(3, 3) + max_heap.insert(25, 25) + max_heap.insert(1, 1) + max_heap.insert(2, 2) + max_heap.insert(7, 7) + assert str(max_heap) == \ + ("[(100, 100, [1, 2]), (19, 19, [3, 4]), " + "(36, 36, [5, 6]), (17, 17, [7, 8]), " + "(3, 3, []), (25, 25, []), (1, 1, []), " + "(2, 2, []), (7, 7, [])]") + + assert max_heap.extract().key == 100 + + expected_sorted_elements = [36, 25, 19, 17, 7, 3, 2, 1] + l = max_heap.heap[0].left + l = max_heap.heap[0].right + sorted_elements = [] + for _ in range(8): + sorted_elements.append(max_heap.extract().key) + assert expected_sorted_elements == sorted_elements + + elements = [ + TreeNode(7, 7), TreeNode(25, 25), TreeNode(100, 100), + TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), + TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) + ] + min_heap = BinaryHeap(elements=elements, heap_property="min") + assert min_heap.extract().key == 1 + + expected_sorted_elements = [2, 3, 7, 17, 19, 25, 36, 100] + sorted_elements = [min_heap.extract().key for _ in range(8)] + assert expected_sorted_elements == sorted_elements + + non_TreeNode_elements = [ + (7, 7), TreeNode(25, 25), TreeNode(100, 100), + TreeNode(1, 1), (2, 2), TreeNode(3, 3), + TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) + ] + assert raises(TypeError, lambda: + BinaryHeap(elements = non_TreeNode_elements, heap_property='min')) + + non_TreeNode_elements = DynamicOneDimensionalArray(int, 0) + non_TreeNode_elements.append(1) + non_TreeNode_elements.append(2) + assert raises(TypeError, lambda: + BinaryHeap(elements = non_TreeNode_elements, heap_property='min')) + + non_heapable = "[1, 2, 3]" + assert raises(ValueError, lambda: + BinaryHeap(elements = non_heapable, heap_property='min')) + +def test_TernaryHeap(): + max_heap = TernaryHeap(heap_property="max") + assert raises(IndexError, lambda: max_heap.extract()) + max_heap.insert(100, 100) + max_heap.insert(19, 19) + max_heap.insert(36, 36) + max_heap.insert(17, 17) + max_heap.insert(3, 3) + max_heap.insert(25, 25) + max_heap.insert(1, 1) + max_heap.insert(2, 2) + max_heap.insert(7, 7) + assert str(max_heap) == \ + ('[(100, 100, [1, 2, 3]), (25, 25, [4, 5, 6]), ' + '(36, 36, [7, 8]), (17, 17, []), ' + '(3, 3, []), (19, 19, []), (1, 1, []), ' + '(2, 2, []), (7, 7, [])]') + + assert max_heap.extract().key == 100 + + expected_sorted_elements = [36, 25, 19, 17, 7, 3, 2, 1] + sorted_elements = [] + for _ in range(8): + sorted_elements.append(max_heap.extract().key) + assert expected_sorted_elements == sorted_elements + + elements = [ + TreeNode(7, 7), TreeNode(25, 25), TreeNode(100, 100), + TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), + TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) + ] + min_heap = TernaryHeap(elements=elements, heap_property="min") + expected_extracted_element = min_heap.heap[0].key + assert min_heap.extract().key == expected_extracted_element + + expected_sorted_elements = [2, 3, 7, 17, 19, 25, 36, 100] + sorted_elements = [min_heap.extract().key for _ in range(8)] + assert expected_sorted_elements == sorted_elements + +def test_DHeap(): + assert raises(ValueError, lambda: DHeap(heap_property="none", d=4)) + max_heap = DHeap(heap_property="max", d=5) + assert raises(IndexError, lambda: max_heap.extract()) + max_heap.insert(100, 100) + max_heap.insert(19, 19) + max_heap.insert(36, 36) + max_heap.insert(17, 17) + max_heap.insert(3, 3) + max_heap.insert(25, 25) + max_heap.insert(1, 1) + max_heap = DHeap(max_heap.heap, heap_property="max", d=4) + max_heap.insert(2, 2) + max_heap.insert(7, 7) + assert str(max_heap) == \ + ('[(100, 100, [1, 2, 3, 4]), (25, 25, [5, 6, 7, 8]), ' + '(36, 36, []), (17, 17, []), (3, 3, []), (19, 19, []), ' + '(1, 1, []), (2, 2, []), (7, 7, [])]') + + assert max_heap.extract().key == 100 + + expected_sorted_elements = [36, 25, 19, 17, 7, 3, 2, 1] + sorted_elements = [] + for _ in range(8): + sorted_elements.append(max_heap.extract().key) + assert expected_sorted_elements == sorted_elements + + elements = [ + TreeNode(7, 7), TreeNode(25, 25), TreeNode(100, 100), + TreeNode(1, 1), TreeNode(2, 2), TreeNode(3, 3), + TreeNode(17, 17), TreeNode(19, 19), TreeNode(36, 36) + ] + min_heap = DHeap(elements=DynamicOneDimensionalArray(TreeNode, 9, elements), heap_property="min") + assert min_heap.extract().key == 1 + + expected_sorted_elements = [2, 3, 7, 17, 19, 25, 36, 100] + sorted_elements = [min_heap.extract().key for _ in range(8)] + assert expected_sorted_elements == sorted_elements + +def test_BinomialHeap(): + + # Corner cases + assert raises(TypeError, lambda: + BinomialHeap( + root_list=[BinomialTreeNode(1, 1), None]) + ) is True + tree1 = BinomialTree(BinomialTreeNode(1, 1), 0) + tree2 = BinomialTree(BinomialTreeNode(2, 2), 0) + bh = BinomialHeap(root_list=[tree1, tree2]) + assert raises(TypeError, lambda: + bh.merge_tree(BinomialTreeNode(2, 2), None)) + assert raises(TypeError, lambda: + bh.merge(None)) + + # Testing BinomialHeap.merge + nodes = [BinomialTreeNode(1, 1), # 0 + BinomialTreeNode(3, 3), # 1 + BinomialTreeNode(9, 9), # 2 + BinomialTreeNode(11, 11), # 3 + BinomialTreeNode(6, 6), # 4 + BinomialTreeNode(14, 14), # 5 + BinomialTreeNode(2, 2), # 6 + BinomialTreeNode(7, 7), # 7 + BinomialTreeNode(4, 4), # 8 + BinomialTreeNode(8, 8), # 9 + BinomialTreeNode(12, 12), # 10 + BinomialTreeNode(10, 10), # 11 + BinomialTreeNode(5, 5), # 12 + BinomialTreeNode(21, 21)] # 13 + + nodes[2].add_children(nodes[3]) + nodes[4].add_children(nodes[5]) + nodes[6].add_children(nodes[9], nodes[8], nodes[7]) + nodes[7].add_children(nodes[11], nodes[10]) + nodes[8].add_children(nodes[12]) + nodes[10].add_children(nodes[13]) + + tree11 = BinomialTree(nodes[0], 0) + tree12 = BinomialTree(nodes[2], 1) + tree13 = BinomialTree(nodes[6], 3) + tree21 = BinomialTree(nodes[1], 0) + + heap1 = BinomialHeap(root_list=[tree11, tree12, tree13]) + heap2 = BinomialHeap(root_list=[tree21]) + + def bfs(heap): + bfs_trav = [] + for i in range(len(heap.root_list)): + layer = [] + bfs_q = Queue() + bfs_q.append(heap.root_list[i].root) + while len(bfs_q) != 0: + curr_node = bfs_q.popleft() + if curr_node is not None: + layer.append(curr_node.key) + for _i in range(curr_node.children._last_pos_filled + 1): + bfs_q.append(curr_node.children[_i]) + if layer != []: + bfs_trav.append(layer) + return bfs_trav + + heap1.merge(heap2) + expected_bfs_trav = [[1, 3, 9, 11], [2, 8, 4, 7, 5, 10, 12, 21]] + assert bfs(heap1) == expected_bfs_trav + + # Testing Binomial.find_minimum + assert heap1.find_minimum().key == 1 + + # Testing Binomial.delete_minimum + heap1.delete_minimum() + assert bfs(heap1) == [[3], [9, 11], [2, 8, 4, 7, 5, 10, 12, 21]] + assert raises(ValueError, lambda: heap1.decrease_key(nodes[3], 15)) + heap1.decrease_key(nodes[3], 0) + assert bfs(heap1) == [[3], [0, 9], [2, 8, 4, 7, 5, 10, 12, 21]] + heap1.delete(nodes[12]) + assert bfs(heap1) == [[3, 8], [0, 9, 2, 7, 4, 10, 12, 21]] + + # Testing BinomialHeap.insert + heap = BinomialHeap() + assert raises(IndexError, lambda: heap.find_minimum()) + heap.insert(1, 1) + heap.insert(3, 3) + heap.insert(6, 6) + heap.insert(9, 9) + heap.insert(14, 14) + heap.insert(11, 11) + heap.insert(2, 2) + heap.insert(7, 7) + assert bfs(heap) == [[1, 3, 6, 2, 9, 7, 11, 14]] From 55b90b41b2772bd764c9a348752acb41f75ef3be Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:53:56 -0700 Subject: [PATCH 29/35] Fix IndexError in ChaCha20 quarter-round by correcting 2D array indexing and documenting the algorithm --- pydatastructs/crypto/ChaCha20.py | 89 ++++++++++++++++++++++++-------- 1 file changed, 67 insertions(+), 22 deletions(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index 0b7cade4b..3e00c491e 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -40,28 +40,73 @@ def __repr__(self): """Returns a string representation of the object for debugging.""" return f"" - def _quarter_round(self, state: List[int], a: int, b: int, c: int, d: int): - state[a] = (state[a] + state[b]) % (2**32) - state[d] ^= state[a] - state[d] = ((state[d] << 16) | (state[d] >> 16)) % (2**32) - state[c] = (state[c] + state[d]) % (2**32) - state[b] ^= state[c] - state[b] = ((state[b] << 12) | (state[b] >> 20)) % (2**32) - state[a] = (state[a] + state[b]) % (2**32) - state[d] ^= state[a] - state[d] = ((state[d] << 8) | (state[d] >> 24)) % (2**32) - state[c] = (state[c] + state[d]) % (2**32) - state[b] ^= state[c] - state[b] = ((state[b] << 7) | (state[b] >> 25)) % (2**32) - def _double_round(self, state: List[int]): - self._quarter_round(state, 0, 4, 8, 12) - self._quarter_round(state, 1, 5, 9, 13) - self._quarter_round(state, 2, 6, 10, 14) - self._quarter_round(state, 3, 7, 11, 15) - self._quarter_round(state, 0, 5, 10, 15) - self._quarter_round(state, 1, 6, 11, 12) - self._quarter_round(state, 2, 7, 8, 13) - self._quarter_round(state, 3, 4, 9, 14) + + def _quarter_round(self, state: np.ndarray, a: tuple, b: tuple, c: tuple, d: tuple): + + """ + Performs the ChaCha20 quarter-round operation on the 4x4 state matrix. + + The quarter-round consists of four operations (Add, XOR, and Rotate) performed on + four elements of the state. It is a core component of the ChaCha20 algorithm, ensuring + diffusion of bits for cryptographic security. + + Parameters: + ----------- + state : np.ndarray + A 4x4 matrix (NumPy array) representing the ChaCha20 state. + + a, b, c, d : tuple + Each tuple represents the (row, column) indices of four elements in the state matrix + to be processed in the quarter-round. + + Operations: + ----------- + - Add: Adds two values modulo 2^32. + - XOR: Performs a bitwise XOR operation. + - Rotate: Rotates bits (circular shift) to the left. + + Formula for the quarter-round (performed four times): + ----------------------------------------------------- + 1. a += b; d ^= a; d <<<= 16 + 2. c += d; b ^= c; b <<<= 12 + 3. a += b; d ^= a; d <<<= 8 + 4. c += d; b ^= c; b <<<= 7 + + """ + ax, ay = a + bx, by = b + cx, cy = c + dx, dy = d + + state[ax, ay] = (state[ax, ay] + state[bx, by]) % (2**32) + state[dx, dy] ^= state[ax, ay] + state[dx, dy] = ((state[dx, dy] << 16) | (state[dx, dy] >> 16)) % (2**32) + + state[cx, cy] = (state[cx, cy] + state[dx, dy]) % (2**32) + state[bx, by] ^= state[cx, cy] + state[bx, by] = ((state[bx, by] << 12) | (state[bx, by] >> 20)) % (2**32) + + state[ax, ay] = (state[ax, ay] + state[bx, by]) % (2**32) + state[dx, dy] ^= state[ax, ay] + state[dx, dy] = ((state[dx, dy] << 8) | (state[dx, dy] >> 24)) % (2**32) + + state[cx, cy] = (state[cx, cy] + state[dx, dy]) % (2**32) + state[bx, by] ^= state[cx, cy] + state[bx, by] = ((state[bx, by] << 7) | (state[bx, by] >> 25)) % (2**32) + + def _double_round(self, state: np.ndarray): + + self._quarter_round(state, (0, 0), (1, 0), (2, 0), (3, 0)) + self._quarter_round(state, (0, 1), (1, 1), (2, 1), (3, 1)) + self._quarter_round(state, (0, 2), (1, 2), (2, 2), (3, 2)) + self._quarter_round(state, (0, 3), (1, 3), (2, 3), (3, 3)) + + self._quarter_round(state, (0, 0), (1, 1), (2, 2), (3, 3)) + self._quarter_round(state, (0, 1), (1, 2), (2, 3), (3, 0)) + self._quarter_round(state, (0, 2), (1, 3), (2, 0), (3, 1)) + self._quarter_round(state, (0, 3), (1, 0), (2, 1), (3, 2)) + + def _chacha20_block(self, counter: int) -> bytes: """ Generates a 64-byte keystream block from 16-word (512-bit) state From 2858fff17e53f3d1f6a672cc7522b520fcf6df12 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 14:59:01 -0700 Subject: [PATCH 30/35] Remove trailing whitespace --- pydatastructs/crypto/ChaCha20.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index 3e00c491e..ecd9aefd1 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -40,9 +40,9 @@ def __repr__(self): """Returns a string representation of the object for debugging.""" return f"" - + def _quarter_round(self, state: np.ndarray, a: tuple, b: tuple, c: tuple, d: tuple): - + """ Performs the ChaCha20 quarter-round operation on the 4x4 state matrix. @@ -54,7 +54,7 @@ def _quarter_round(self, state: np.ndarray, a: tuple, b: tuple, c: tuple, d: tup ----------- state : np.ndarray A 4x4 matrix (NumPy array) representing the ChaCha20 state. - + a, b, c, d : tuple Each tuple represents the (row, column) indices of four elements in the state matrix to be processed in the quarter-round. @@ -71,31 +71,31 @@ def _quarter_round(self, state: np.ndarray, a: tuple, b: tuple, c: tuple, d: tup 2. c += d; b ^= c; b <<<= 12 3. a += b; d ^= a; d <<<= 8 4. c += d; b ^= c; b <<<= 7 - + """ ax, ay = a bx, by = b cx, cy = c dx, dy = d - + state[ax, ay] = (state[ax, ay] + state[bx, by]) % (2**32) state[dx, dy] ^= state[ax, ay] state[dx, dy] = ((state[dx, dy] << 16) | (state[dx, dy] >> 16)) % (2**32) - + state[cx, cy] = (state[cx, cy] + state[dx, dy]) % (2**32) state[bx, by] ^= state[cx, cy] state[bx, by] = ((state[bx, by] << 12) | (state[bx, by] >> 20)) % (2**32) - + state[ax, ay] = (state[ax, ay] + state[bx, by]) % (2**32) state[dx, dy] ^= state[ax, ay] state[dx, dy] = ((state[dx, dy] << 8) | (state[dx, dy] >> 24)) % (2**32) - + state[cx, cy] = (state[cx, cy] + state[dx, dy]) % (2**32) state[bx, by] ^= state[cx, cy] state[bx, by] = ((state[bx, by] << 7) | (state[bx, by] >> 25)) % (2**32) - + def _double_round(self, state: np.ndarray): - + self._quarter_round(state, (0, 0), (1, 0), (2, 0), (3, 0)) self._quarter_round(state, (0, 1), (1, 1), (2, 1), (3, 1)) self._quarter_round(state, (0, 2), (1, 2), (2, 2), (3, 2)) @@ -105,7 +105,7 @@ def _double_round(self, state: np.ndarray): self._quarter_round(state, (0, 1), (1, 2), (2, 3), (3, 0)) self._quarter_round(state, (0, 2), (1, 3), (2, 0), (3, 1)) self._quarter_round(state, (0, 3), (1, 0), (2, 1), (3, 2)) - + def _chacha20_block(self, counter: int) -> bytes: """ From 1b82c3dc9405d67db39c411f1824e55a56082125 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 15:02:14 -0700 Subject: [PATCH 31/35] Fix OverflowError in ChaCha20 by using NumPy in-place operations --- pydatastructs/crypto/ChaCha20.py | 34 +++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index ecd9aefd1..6aae2ed8a 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -78,21 +78,33 @@ def _quarter_round(self, state: np.ndarray, a: tuple, b: tuple, c: tuple, d: tup cx, cy = c dx, dy = d - state[ax, ay] = (state[ax, ay] + state[bx, by]) % (2**32) + state[ax, ay] += state[bx, by] state[dx, dy] ^= state[ax, ay] - state[dx, dy] = ((state[dx, dy] << 16) | (state[dx, dy] >> 16)) % (2**32) + state[dx, dy] = np.bitwise_or( + np.left_shift(state[dx, dy], 16), + np.right_shift(state[dx, dy], 16) + ) - state[cx, cy] = (state[cx, cy] + state[dx, dy]) % (2**32) + state[cx, cy] += state[dx, dy] state[bx, by] ^= state[cx, cy] - state[bx, by] = ((state[bx, by] << 12) | (state[bx, by] >> 20)) % (2**32) - - state[ax, ay] = (state[ax, ay] + state[bx, by]) % (2**32) - state[dx, dy] ^= state[ax, ay] - state[dx, dy] = ((state[dx, dy] << 8) | (state[dx, dy] >> 24)) % (2**32) - - state[cx, cy] = (state[cx, cy] + state[dx, dy]) % (2**32) + state[bx, by] = np.bitwise_or( + np.left_shift(state[bx, by], 12), + np.right_shift(state[bx, by], 20) + ) + + state[ax, ay] += state[bx, by] + state[dx, dy] ^= state[ax, ay] + state[dx, dy] = np.bitwise_or( + np.left_shift(state[dx, dy], 8), + np.right_shift(state[dx, dy], 24) + ) + + state[cx, cy] += state[dx, dy] state[bx, by] ^= state[cx, cy] - state[bx, by] = ((state[bx, by] << 7) | (state[bx, by] >> 25)) % (2**32) + state[bx, by] = np.bitwise_or( + np.left_shift(state[bx, by], 7), + np.right_shift(state[bx, by], 25) + ) def _double_round(self, state: np.ndarray): From 4ea0c72ba8a86886439899c4202f5447aa12d3e0 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 15:13:41 -0700 Subject: [PATCH 32/35] Implement explicit np.uint32 conversions and modular addition using & 0xFFFFFFFF --- pydatastructs/crypto/ChaCha20.py | 37 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index 6aae2ed8a..aa7c9f257 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -78,34 +78,33 @@ def _quarter_round(self, state: np.ndarray, a: tuple, b: tuple, c: tuple, d: tup cx, cy = c dx, dy = d - state[ax, ay] += state[bx, by] + state[ax, ay] = ((state[ax, ay].astype(np.uint32) + state[bx, by].astype(np.uint32)) & 0xFFFFFFFF).astype(np.uint32) state[dx, dy] ^= state[ax, ay] state[dx, dy] = np.bitwise_or( - np.left_shift(state[dx, dy], 16), - np.right_shift(state[dx, dy], 16) - ) + np.left_shift(state[dx, dy].astype(np.uint32), 16) & 0xFFFFFFFF, + np.right_shift(state[dx, dy].astype(np.uint32), 16) +) - state[cx, cy] += state[dx, dy] + state[cx, cy] = ((state[cx, cy].astype(np.uint32) + state[dx, dy].astype(np.uint32)) & 0xFFFFFFFF).astype(np.uint32) state[bx, by] ^= state[cx, cy] state[bx, by] = np.bitwise_or( - np.left_shift(state[bx, by], 12), - np.right_shift(state[bx, by], 20) - ) + np.left_shift(state[bx, by].astype(np.uint32), 12) & 0xFFFFFFFF, + np.right_shift(state[bx, by].astype(np.uint32), 20) +) - state[ax, ay] += state[bx, by] - state[dx, dy] ^= state[ax, ay] + state[ax, ay] = ((state[ax, ay].astype(np.uint32) + state[bx, by].astype(np.uint32)) & 0xFFFFFFFF).astype(np.uint32) + state[dx, dy] ^= state[ax, ay] state[dx, dy] = np.bitwise_or( - np.left_shift(state[dx, dy], 8), - np.right_shift(state[dx, dy], 24) - ) + np.left_shift(state[dx, dy].astype(np.uint32), 8) & 0xFFFFFFFF, + np.right_shift(state[dx, dy].astype(np.uint32), 24) +) - state[cx, cy] += state[dx, dy] + state[cx, cy] = ((state[cx, cy].astype(np.uint32) + state[dx, dy].astype(np.uint32)) & 0xFFFFFFFF).astype(np.uint32) state[bx, by] ^= state[cx, cy] state[bx, by] = np.bitwise_or( - np.left_shift(state[bx, by], 7), - np.right_shift(state[bx, by], 25) - ) - + np.left_shift(state[bx, by].astype(np.uint32), 7) & 0xFFFFFFFF, + np.right_shift(state[bx, by].astype(np.uint32), 25) +) def _double_round(self, state: np.ndarray): self._quarter_round(state, (0, 0), (1, 0), (2, 0), (3, 0)) @@ -136,7 +135,7 @@ def _chacha20_block(self, counter: int) -> bytes: working_state = dp(state) for _ in range(10): self._double_round(working_state) - final_state = (working_state + state) % (2**32) + final_state = np.bitwise_and(working_state + state, np.uint32(0xFFFFFFFF)) return struct.pack('<16I', *final_state.flatten()) def _apply_keystream(self, data: bytes) -> bytes: From 24c4e6af50b02f29982b540b75ca1d3d587f6877 Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 15:21:02 -0700 Subject: [PATCH 33/35] Fix typo in decrypt method --- pydatastructs/crypto/ChaCha20.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index aa7c9f257..1f672ddce 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -203,7 +203,7 @@ def decrypt(self, ciphertext: bytes) -> bytes: Returns: bytes: The resulting plaintext. """ - return self.apply_keystream(ciphertext) + return self._apply_keystream(ciphertext) def reset(self, counter: int = 0): """Resets the ChaCha20 counter to the specified value (default is 0).""" From 7f084e02f58814488758f9c275ff92b40a32f1bf Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 15:38:49 -0700 Subject: [PATCH 34/35] Fix ChaCha20 key reuse test by truncating plaintexts to equal length --- pydatastructs/crypto/ChaCha20.py | 3 +++ pydatastructs/crypto/tests/test_chacha20.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index 1f672ddce..661df54d6 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -167,6 +167,7 @@ def _apply_keystream(self, data: bytes) -> bytes: chunk = data[start:start + chunk_size] start += chunk_size keystream = self._chacha20_block(self.counter) + self.counter += 1 xor_block = [] for idx in range(len(chunk)): @@ -188,6 +189,7 @@ def encrypt(self, plaintext: bytes) -> bytes: Returns: bytes: The resulting ciphertext. """ + self.reset(counter=0) return self._apply_keystream(plaintext) def decrypt(self, ciphertext: bytes) -> bytes: @@ -203,6 +205,7 @@ def decrypt(self, ciphertext: bytes) -> bytes: Returns: bytes: The resulting plaintext. """ + self.reset(counter=0) return self._apply_keystream(ciphertext) def reset(self, counter: int = 0): diff --git a/pydatastructs/crypto/tests/test_chacha20.py b/pydatastructs/crypto/tests/test_chacha20.py index 3e4fd188a..36f11cece 100644 --- a/pydatastructs/crypto/tests/test_chacha20.py +++ b/pydatastructs/crypto/tests/test_chacha20.py @@ -99,6 +99,10 @@ def test_key_reuse_simple(): plaintext1 = b"Hello, this is message one!" plaintext2 = b"Hi there, this is message two!" + min_len = min(len(plaintext1), len(plaintext2)) + plaintext1 = plaintext1[:min_len] + plaintext2 = plaintext2[:min_len] + ciphertext1 = cipher1.encrypt(plaintext1) ciphertext2 = cipher2.encrypt(plaintext2) From 717fe8e45d80659eed1d316ac2bf8cb0ba07efbb Mon Sep 17 00:00:00 2001 From: 30215210 Date: Tue, 18 Feb 2025 15:51:43 -0700 Subject: [PATCH 35/35] Remove whitespace --- pydatastructs/crypto/ChaCha20.py | 2 +- pydatastructs/crypto/tests/test_chacha20.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pydatastructs/crypto/ChaCha20.py b/pydatastructs/crypto/ChaCha20.py index 661df54d6..a7965980f 100644 --- a/pydatastructs/crypto/ChaCha20.py +++ b/pydatastructs/crypto/ChaCha20.py @@ -167,7 +167,7 @@ def _apply_keystream(self, data: bytes) -> bytes: chunk = data[start:start + chunk_size] start += chunk_size keystream = self._chacha20_block(self.counter) - + self.counter += 1 xor_block = [] for idx in range(len(chunk)): diff --git a/pydatastructs/crypto/tests/test_chacha20.py b/pydatastructs/crypto/tests/test_chacha20.py index 36f11cece..b605c49fc 100644 --- a/pydatastructs/crypto/tests/test_chacha20.py +++ b/pydatastructs/crypto/tests/test_chacha20.py @@ -102,7 +102,7 @@ def test_key_reuse_simple(): min_len = min(len(plaintext1), len(plaintext2)) plaintext1 = plaintext1[:min_len] plaintext2 = plaintext2[:min_len] - + ciphertext1 = cipher1.encrypt(plaintext1) ciphertext2 = cipher2.encrypt(plaintext2)