From 5cce7ad8b20875a79a2f9653dfaf7c7a0fe3528f Mon Sep 17 00:00:00 2001
From: Michał Sośnicki
Date: Wed, 20 Nov 2024 06:17:36 +0100
Subject: [PATCH] fix: batch byte limit is properly checked on a new step

---
 .../core/components/aggregating_queue.py | 16 +++++----
 tests/unit/test_aggregating_queue.py     | 36 +++++++++++++++++++
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/src/neptune_scale/core/components/aggregating_queue.py b/src/neptune_scale/core/components/aggregating_queue.py
index 5611a173..dd0f05ac 100644
--- a/src/neptune_scale/core/components/aggregating_queue.py
+++ b/src/neptune_scale/core/components/aggregating_queue.py
@@ -93,11 +93,12 @@ def get(self) -> BatchedOperations:
             if element is None:
                 break
 
-            if not batch_operations or element.batch_key != last_batch_key:
+            if not batch_operations:
                 new_operation = RunOperation()
                 new_operation.ParseFromString(element.operation)
                 batch_operations.append(new_operation)
                 last_batch_key = element.batch_key
+                batch_bytes += len(element.operation)
             else:
                 if not element.is_batchable:
                     logger.debug("Batch closed due to next operation not being batchable")
@@ -111,21 +112,22 @@ def get(self) -> BatchedOperations:
                 new_operation = RunOperation()
                 new_operation.ParseFromString(element.operation)
-                merge_run_operation(batch_operations[-1], new_operation)
+                if element.batch_key != last_batch_key:
+                    batch_operations.append(new_operation)
+                    last_batch_key = element.batch_key
+                else:
+                    merge_run_operation(batch_operations[-1], new_operation)
+                batch_bytes += element.metadata_size
 
             batch_sequence_id = element.sequence_id
             batch_timestamp = element.timestamp
-            if element.metadata_size is not None:
-                batch_bytes += element.metadata_size
-            else:
-                batch_bytes += len(element.operation)
 
             elements_in_batch += 1
 
             self.commit()
 
             if not element.is_batchable:
-                logger.debug("Batch closed due to first not being batchable")
+                logger.debug("Batch closed due to the first element not being batchable")
                 break
 
         t1 = time.monotonic()
 
diff --git a/tests/unit/test_aggregating_queue.py b/tests/unit/test_aggregating_queue.py
index 978d5d40..723c3302 100644
--- a/tests/unit/test_aggregating_queue.py
+++ b/tests/unit/test_aggregating_queue.py
@@ -182,6 +182,42 @@ def test__batching():
     assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"])
 
 
+@freeze_time("2024-09-01")
+def test__queue_element_size_limit_with_different_steps():
+    # given
+    update1 = UpdateRunSnapshot(step=Step(whole=1), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)})
+    update2 = UpdateRunSnapshot(step=Step(whole=2), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)})
+    operation1 = RunOperation(update=update1)
+    operation2 = RunOperation(update=update2)
+    element1 = SingleOperation(
+        sequence_id=1,
+        timestamp=time.process_time(),
+        operation=operation1.SerializeToString(),
+        is_batchable=True,
+        metadata_size=update1.ByteSize(),
+        batch_key=1.0,
+    )
+    element2 = SingleOperation(
+        sequence_id=2,
+        timestamp=time.process_time(),
+        operation=operation2.SerializeToString(),
+        is_batchable=True,
+        metadata_size=update2.ByteSize(),
+        batch_key=2.0,
+    )
+
+    # and
+    queue = AggregatingQueue(max_queue_size=2, max_queue_element_size=update1.ByteSize())
+
+    # when
+    queue.put_nowait(element=element1)
+    queue.put_nowait(element=element2)
+
+    # then
+    assert queue.get() == BatchedOperations(sequence_id=1, timestamp=element1.timestamp, operation=element1.operation)
+    assert queue.get() == BatchedOperations(sequence_id=2, timestamp=element2.timestamp, operation=element2.operation)
+
+
 @freeze_time("2024-09-01")
 def test__not_merge_two_run_creation():
     # given
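
Note on the behavioural core of the patch: previously, any element arriving with a new batch_key (a new step) re-entered the first-element branch and so bypassed the is_batchable and byte-limit checks that live in the else branch; now only the genuinely first element of a batch does, and every later element is checked and then either appended (new step) or merged (same step). Below is a minimal, self-contained sketch of that predicate change. The helper names and boolean framing are illustrative only, not part of the neptune_scale API; only the quoted conditions come from the diff above.

def skips_batch_checks_old(batch_is_empty: bool, new_step: bool) -> bool:
    # old predicate: `if not batch_operations or element.batch_key != last_batch_key:`
    # an element with a new step was routed into the "first element" branch,
    # past the is_batchable and byte-limit checks of the else branch
    return batch_is_empty or new_step


def skips_batch_checks_new(batch_is_empty: bool) -> bool:
    # new predicate: `if not batch_operations:` -- every element after the
    # first goes through the checked else branch, whatever its step
    return batch_is_empty


# the scenario exercised by test__queue_element_size_limit_with_different_steps:
# a second element arrives carrying a different step
assert skips_batch_checks_old(batch_is_empty=False, new_step=True)  # limit never checked
assert not skips_batch_checks_new(batch_is_empty=False)             # limit is checked

This is why the new test expects two separate BatchedOperations: with max_queue_element_size set to update1.ByteSize(), the second update's metadata_size cannot fit into the open batch, so the checked path closes the first batch instead of merging past the limit as the old code did.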