diff --git a/.github/workflows/tests-e2e.yml b/.github/workflows/tests-e2e.yml index 43d09a05..cdb6f0e9 100644 --- a/.github/workflows/tests-e2e.yml +++ b/.github/workflows/tests-e2e.yml @@ -33,6 +33,7 @@ jobs: pip install -r dev_requirements.txt - name: Run tests + timeout-minutes: 30 env: NEPTUNE_API_TOKEN: ${{ secrets.E2E_API_TOKEN }} NEPTUNE_E2E_PROJECT: ${{ secrets.E2E_PROJECT }} diff --git a/src/neptune_scale/api/attribute.py b/src/neptune_scale/api/attribute.py index 25cfaa20..633c60ef 100644 --- a/src/neptune_scale/api/attribute.py +++ b/src/neptune_scale/api/attribute.py @@ -15,7 +15,8 @@ from neptune_scale.api.metrics import Metrics from neptune_scale.sync.metadata_splitter import MetadataSplitter -from neptune_scale.sync.operations_queue import OperationsQueue +from neptune_scale.sync.operations_repository import OperationsRepository +from neptune_scale.sync.sequence_tracker import SequenceTracker __all__ = ("Attribute", "AttributeStore") @@ -65,11 +66,14 @@ class AttributeStore: end consuming the queue (which would be SyncProcess). 
""" - def __init__(self, project: str, run_id: str, operations_queue: OperationsQueue) -> None: + def __init__( + self, project: str, run_id: str, operations_repo: OperationsRepository, sequence_tracker: SequenceTracker + ) -> None: self._project = project self._run_id = run_id - self._operations_queue = operations_queue + self._operations_repo = operations_repo self._attributes: dict[str, Attribute] = {} + self._sequence_tracker = sequence_tracker def __getitem__(self, path: str) -> "Attribute": path = cleanup_path(path) @@ -108,9 +112,10 @@ def log( remove_tags=tags_remove, ) - for operation, metadata_size in splitter: - key = metrics.batch_key() if metrics is not None else None - self._operations_queue.enqueue(operation=operation, size=metadata_size, key=key) + operations = list(splitter) + sequence_id = self._operations_repo.save_update_run_snapshots(operations) + + self._sequence_tracker.update_sequence_id(sequence_id) class Attribute: diff --git a/src/neptune_scale/api/metrics.py b/src/neptune_scale/api/metrics.py index eb102da3..ff2e5c7c 100644 --- a/src/neptune_scale/api/metrics.py +++ b/src/neptune_scale/api/metrics.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Hashable from dataclasses import dataclass from typing import ( Optional, @@ -40,6 +39,3 @@ def __post_init__(self) -> None: self.preview_completion = None if self.preview_completion is not None: verify_value_between("preview_completion", self.preview_completion, 0.0, 1.0) - - def batch_key(self) -> Hashable: - return (self.step, self.preview, self.preview_completion) diff --git a/src/neptune_scale/api/run.py b/src/neptune_scale/api/run.py index c0b786bc..9278a842 100644 --- a/src/neptune_scale/api/run.py +++ b/src/neptune_scale/api/run.py @@ -4,6 +4,15 @@ from __future__ import annotations +from pathlib import Path +from types import TracebackType + +from neptune_scale.sync.operations_repository import ( + DB_VERSION, + Metadata, + OperationsRepository, +) + 
__all__ = ["Run"] import atexit @@ -22,7 +31,6 @@ from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ForkPoint from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation from neptune_scale.api.attribute import AttributeStore from neptune_scale.api.metrics import Metrics @@ -35,6 +43,8 @@ ) from neptune_scale.exceptions import ( NeptuneApiTokenNotProvided, + NeptuneConflictingDataInLocalStorage, + NeptuneLocalStorageInUnsupportedVersion, NeptuneProjectNotProvided, ) from neptune_scale.net.serialization import ( @@ -46,20 +56,15 @@ ErrorsQueue, ) from neptune_scale.sync.lag_tracking import LagTracker -from neptune_scale.sync.operations_queue import OperationsQueue from neptune_scale.sync.parameters import ( MAX_EXPERIMENT_NAME_LENGTH, - MAX_QUEUE_SIZE, MAX_RUN_ID_LENGTH, MINIMAL_WAIT_FOR_ACK_SLEEP_TIME, MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, STOP_MESSAGE_FREQUENCY, ) +from neptune_scale.sync.sequence_tracker import SequenceTracker from neptune_scale.sync.sync_process import SyncProcess -from neptune_scale.util.abstract import ( - Resource, - WithResources, -) from neptune_scale.util.envs import ( API_TOKEN_ENV_NAME, PROJECT_ENV_NAME, @@ -74,7 +79,7 @@ logger = get_logger() -class Run(WithResources, AbstractContextManager): +class Run(AbstractContextManager): """ Representation of tracked metadata. """ @@ -91,7 +96,7 @@ def __init__( creation_time: Optional[datetime] = None, fork_run_id: Optional[str] = None, fork_step: Optional[Union[int, float]] = None, - max_queue_size: int = MAX_QUEUE_SIZE, + max_queue_size: Optional[int] = None, async_lag_threshold: Optional[float] = None, on_async_lag_callback: Optional[Callable[[], None]] = None, on_queue_full_callback: Optional[Callable[[BaseException, Optional[float]], None]] = None, @@ -114,7 +119,7 @@ def __init__( creation_time: Custom creation time of the run. 
fork_run_id: If forking from an existing run, ID of the run to fork from. fork_step: If forking from an existing run, step number to fork from. - max_queue_size: Maximum number of operations in a queue. + max_queue_size: Deprecated. async_lag_threshold: Threshold for the duration between the queueing and synchronization of an operation (in seconds). If the duration exceeds the threshold, the callback function is triggered. on_async_lag_callback: Callback function triggered when the duration between the queueing and synchronization @@ -134,7 +139,6 @@ def __init__( verify_type("creation_time", creation_time, (datetime, type(None))) verify_type("fork_run_id", fork_run_id, (str, type(None))) verify_type("fork_step", fork_step, (int, float, type(None))) - verify_type("max_queue_size", max_queue_size, int) verify_type("async_lag_threshold", async_lag_threshold, (int, float, type(None))) verify_type("on_async_lag_callback", on_async_lag_callback, (Callable, type(None))) verify_type("on_queue_full_callback", on_queue_full_callback, (Callable, type(None))) @@ -161,8 +165,8 @@ def __init__( ): raise ValueError("`on_async_lag_callback` must be used with `async_lag_threshold`.") - if max_queue_size < 1: - raise ValueError("`max_queue_size` must be greater than 0.") + if max_queue_size is not None: + logger.warning("`max_queue_size` is deprecated and will be removed in a future version.") project = project or os.environ.get(PROJECT_ENV_NAME) if project: @@ -198,12 +202,24 @@ def __init__( self._run_id: str = run_id self._lock = threading.RLock() - self._operations_queue: OperationsQueue = OperationsQueue( - lock=self._lock, - max_size=max_queue_size, + + operations_repository_path = _create_repository_path(self._project, self._run_id) + self._operations_repo: OperationsRepository = OperationsRepository( + db_path=operations_repository_path, ) + self._operations_repo.init_db() + existing_metadata = self._operations_repo.get_metadata() - self._attr_store: AttributeStore = 
AttributeStore(self._project, self._run_id, self._operations_queue) + # Save metadata if it doesn't exist + if existing_metadata is None: + self._operations_repo.save_metadata(self._project, self._run_id, fork_run_id, fork_step) + else: + _validate_existing_db(existing_metadata, resume, self._project, self._run_id, fork_run_id, fork_step) + + self._sequence_tracker = SequenceTracker() + self._attr_store: AttributeStore = AttributeStore( + self._project, self._run_id, self._operations_repo, self._sequence_tracker + ) self._errors_queue: ErrorsQueue = ErrorsQueue() self._errors_monitor = ErrorsMonitor( @@ -222,21 +238,20 @@ def __init__( self._sync_process = SyncProcess( project=self._project, family=self._run_id, - operations_queue=self._operations_queue.queue, + operations_repository_path=operations_repository_path, errors_queue=self._errors_queue, process_link=self._process_link, api_token=input_api_token, last_queued_seq=self._last_queued_seq, last_ack_seq=self._last_ack_seq, last_ack_timestamp=self._last_ack_timestamp, - max_queue_size=max_queue_size, mode=mode, ) self._lag_tracker: Optional[LagTracker] = None if async_lag_threshold is not None and on_async_lag_callback is not None: self._lag_tracker = LagTracker( errors_queue=self._errors_queue, - operations_queue=self._operations_queue, + sequence_tracker=self._sequence_tracker, last_ack_timestamp=self._last_ack_timestamp, async_lag_threshold=async_lag_threshold, on_async_lag_callback=on_async_lag_callback, @@ -251,6 +266,7 @@ def __init__( self._exit_func: Optional[Callable[[], None]] = atexit.register(self._close) if not resume: + # Create a new run self._create_run( creation_time=datetime.now() if creation_time is None else creation_time, experiment_name=experiment_name, @@ -266,21 +282,6 @@ def _on_child_link_closed(self, _: ProcessLink) -> None: self._is_closing = True self.terminate() - @property - def resources(self) -> tuple[Resource, ...]: - if self._lag_tracker is not None: - return ( - 
self._errors_queue, - self._operations_queue, - self._lag_tracker, - self._errors_monitor, - ) - return ( - self._errors_queue, - self._operations_queue, - self._errors_monitor, - ) - def _close(self, *, wait: bool = True) -> None: with self._lock: if self._is_closing: @@ -310,7 +311,8 @@ def _close(self, *, wait: bool = True) -> None: if threading.current_thread() != self._errors_monitor: self._errors_monitor.join() - super().close() + self._operations_repo.close() + self._errors_queue.close() def terminate(self) -> None: """ @@ -363,6 +365,14 @@ def close(self) -> None: self._exit_func = None self._close(wait=True) + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() + def _create_run( self, creation_time: datetime, @@ -376,17 +386,15 @@ def _create_run( parent_project=self._project, parent_run_id=fork_run_id, step=make_step(number=fork_step) ) - operation = RunOperation( - project=self._project, - run_id=self._run_id, - create=CreateRun( - family=self._run_id, - fork_point=fork_point, - experiment_id=experiment_name, - creation_time=None if creation_time is None else datetime_to_proto(creation_time), - ), + create_run = CreateRun( + family=self._run_id, + fork_point=fork_point, + experiment_id=experiment_name, + creation_time=None if creation_time is None else datetime_to_proto(creation_time), ) - self._operations_queue.enqueue(operation=operation) + + sequence = self._operations_repo.save_create_run(create_run) + self._sequence_tracker.update_sequence_id(sequence) def log_metrics( self, @@ -604,20 +612,20 @@ def _wait( # Handle the case where we get notified on `wait_seq` before we actually wait. # Otherwise, we would unnecessarily block, waiting on a notify_all() that never happens. 
- if wait_seq.value >= self._operations_queue.last_sequence_id: + if wait_seq.value >= self._sequence_tracker.last_sequence_id: break with wait_seq: wait_seq.wait(timeout=wait_time) value = wait_seq.value - last_queued_sequence_id = self._operations_queue.last_sequence_id + last_queued_sequence_id = self._sequence_tracker.last_sequence_id if value == -1: - if self._operations_queue.last_sequence_id != -1: + if self._sequence_tracker.last_sequence_id != -1: last_print_timestamp = print_message( f"Waiting. No operations were {phrase} yet. Operations to sync: %s", - self._operations_queue.last_sequence_id + 1, + self._sequence_tracker.last_sequence_id + 1, last_print=last_print_timestamp, verbose=verbose, ) @@ -692,3 +700,35 @@ def print_message(msg: str, *args: Any, last_print: Optional[float] = None, verb return current_time return last_print + + +def _create_repository_path(project: str, run_id: str) -> Path: + sanitized_project = project.replace("/", "_") + return Path(os.getcwd()) / ".neptune" / f"{sanitized_project}_{run_id}.sqlite3" + + +def _validate_existing_db( + existing_metadata: Metadata, + resume: bool, + project: str, + run_id: str, + fork_run_id: Optional[str], + fork_step: Optional[float], +) -> None: + if existing_metadata.version != DB_VERSION: + raise NeptuneLocalStorageInUnsupportedVersion() + + if existing_metadata.project != project or existing_metadata.run_id != run_id: + # should never happen because we use project and run_id to create the repository path + raise NeptuneConflictingDataInLocalStorage() + + # Check for conflicts when not resuming, because we don't allow fork points in resumed runs + if resume: + return + + if existing_metadata.parent_run_id == fork_run_id and existing_metadata.fork_step == fork_step: + logger.warning("Run already exists in local storage with the same parent run and fork point. 
Resuming the run.") + return + else: + # Same run_id but different fork points + raise NeptuneConflictingDataInLocalStorage() diff --git a/src/neptune_scale/exceptions.py b/src/neptune_scale/exceptions.py index ef83652d..985687c6 100644 --- a/src/neptune_scale/exceptions.py +++ b/src/neptune_scale/exceptions.py @@ -40,6 +40,8 @@ "NeptuneApiTokenNotProvided", "NeptuneTooManyRequestsResponseError", "NeptunePreviewStepNotAfterLastCommittedStep", + "NeptuneConflictingDataInLocalStorage", + "NeptuneLocalStorageInUnsupportedVersion", ) from typing import Any @@ -493,3 +495,13 @@ class NeptunePreviewStepNotAfterLastCommittedStep(NeptuneScaleError): the last fully committed (complete) update. Once a complete value is recorded, any preview updates must only be added for later steps. Please adjust the order of your updates and try again. """ + + +class NeptuneLocalStorageInUnsupportedVersion(NeptuneScaleError): + message = """The local storage database is in an unsupported version. + This may happen when you try to use a database created with a newer version of Neptune Scale with an older version of the library. 
+ Please either upgrade Neptune Scale to the latest version or create a new local storage database.""" + + +class NeptuneConflictingDataInLocalStorage(NeptuneScaleError): + message = """NeptuneConflictingDataInLocalStorage: Conflicting data found in local storage.""" diff --git a/src/neptune_scale/net/api_client.py b/src/neptune_scale/net/api_client.py index f5f9d286..485d4ebc 100644 --- a/src/neptune_scale/net/api_client.py +++ b/src/neptune_scale/net/api_client.py @@ -71,7 +71,6 @@ NeptuneUnableToAuthenticateError, ) from neptune_scale.sync.parameters import REQUEST_TIMEOUT -from neptune_scale.util.abstract import Resource from neptune_scale.util.envs import ALLOW_SELF_SIGNED_CERTIFICATE from neptune_scale.util.logger import get_logger @@ -122,13 +121,15 @@ def create_auth_api_client( ) -class ApiClient(Resource, abc.ABC): +class ApiClient(abc.ABC): @abc.abstractmethod def submit(self, operation: RunOperation, family: str) -> Response[SubmitResponse]: ... @abc.abstractmethod def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ... + def close(self) -> None: ... 
+ class HostedApiClient(ApiClient): def __init__(self, api_token: str) -> None: diff --git a/src/neptune_scale/sync/aggregating_queue.py b/src/neptune_scale/sync/aggregating_queue.py deleted file mode 100644 index 6283840f..00000000 --- a/src/neptune_scale/sync/aggregating_queue.py +++ /dev/null @@ -1,207 +0,0 @@ -from __future__ import annotations - -__all__ = ("AggregatingQueue",) - -import time -from collections.abc import Hashable -from queue import ( - Empty, - Queue, -) -from typing import Optional - -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation - -from neptune_scale.sync.parameters import ( - BATCH_WAIT_TIME_SECONDS, - MAX_BATCH_SIZE, - MAX_QUEUE_ELEMENT_SIZE, -) -from neptune_scale.sync.queue_element import ( - BatchedOperations, - SingleOperation, -) -from neptune_scale.util import get_logger -from neptune_scale.util.abstract import Resource - -logger = get_logger() - - -class AggregatingQueue(Resource): - def __init__( - self, - max_queue_size: int, - max_elements_in_batch: int = MAX_BATCH_SIZE, - max_queue_element_size: int = MAX_QUEUE_ELEMENT_SIZE, - wait_time: float = BATCH_WAIT_TIME_SECONDS, - ) -> None: - self._max_queue_size = max_queue_size - self._max_elements_in_batch = max_elements_in_batch - self._max_queue_element_size = max_queue_element_size - self._wait_time = wait_time - - self._queue: Queue[SingleOperation] = Queue(maxsize=max_queue_size) - self._latest_unprocessed: Optional[SingleOperation] = None - - @property - def maxsize(self) -> int: - return self._max_queue_size - - def put_nowait(self, element: SingleOperation) -> None: - self._queue.put_nowait(element) - - def _get_next(self, timeout: float) -> Optional[SingleOperation]: - # We can assume that each of queue elements are less than MAX_QUEUE_ELEMENT_SIZE because of MetadataSplitter. 
- # We can assume that every queue element has the same project, run id and family - if self._latest_unprocessed is not None: - return self._latest_unprocessed - - try: - self._latest_unprocessed = self._queue.get(timeout=timeout) - return self._latest_unprocessed - except Empty: - return None - - def commit(self) -> None: - self._latest_unprocessed = None - - def get(self) -> BatchedOperations: - start = time.monotonic() - - batch_operations: dict[Hashable, RunOperation] = {} - batch_sequence_id: Optional[int] = None - batch_timestamp: Optional[float] = None - - batch_bytes: int = 0 - elements_in_batch: int = 0 - wait_remaining = self._wait_time - - # Pull operations off the queue until we either reach the maximum size, or - # the specified wait time has passed. This way we maximize the potential batch size. - while True: - t0 = time.monotonic() - - if elements_in_batch >= self._max_elements_in_batch: - logger.debug("Batch closed due to limit of elements in batch %s", elements_in_batch) - break - - element = self._get_next(wait_remaining) - if element is None: - break - - if not batch_operations: - new_operation = RunOperation() - new_operation.ParseFromString(element.operation) - batch_operations[element.batch_key] = new_operation - batch_bytes += len(element.operation) - else: - if not element.is_batchable: - logger.debug("Batch closed due to next operation not being batchable") - break - - assert element.metadata_size is not None # mypy, metadata update always has metadata size - - if batch_bytes + element.metadata_size > self._max_queue_element_size: - logger.debug("Batch closed due to size limit %s", batch_bytes + element.metadata_size) - break - - new_operation = RunOperation() - new_operation.ParseFromString(element.operation) - if element.batch_key not in batch_operations: - batch_operations[element.batch_key] = new_operation - else: - merge_run_operation(batch_operations[element.batch_key], new_operation) - batch_bytes += element.metadata_size - - 
batch_sequence_id = element.sequence_id - batch_timestamp = element.timestamp - - elements_in_batch += 1 - - self.commit() - - if not element.is_batchable: - logger.debug("Batch closed due to the first element not being batchable") - break - - t1 = time.monotonic() - wait_remaining -= t1 - t0 - - if wait_remaining <= 0: - logger.debug("Batch closed due to wait time") - break - - if not batch_operations: - logger.debug(f"Batch is empty after {self._wait_time} seconds of waiting.") - raise Empty - - assert batch_sequence_id is not None # mypy - assert batch_timestamp is not None # mypy - - logger.debug( - "Batched %d operations. Total size %d. Total time %f", - elements_in_batch, - batch_bytes, - time.monotonic() - start, - ) - - batch = create_run_batch(batch_operations) - - return BatchedOperations( - sequence_id=batch_sequence_id, - timestamp=batch_timestamp, - operation=batch.SerializeToString(), - ) - - -def create_run_batch(operations: dict[Hashable, RunOperation]) -> RunOperation: - if len(operations) == 1: - return next(iter(operations.values())) - - batch = None - for _, operation in sorted(operations.items(), key=lambda x: (x[0] is not None, x[0])): - if batch is None: - batch = RunOperation() - batch.project = operation.project - batch.run_id = operation.run_id - batch.create_missing_project = operation.create_missing_project - batch.api_key = operation.api_key - - operation_type = operation.WhichOneof("operation") - if operation_type == "update": - batch.update_batch.snapshots.append(operation.update) - else: - raise ValueError("Cannot batch operation of type %s", operation_type) - - if batch is None: - raise Empty - return batch - - -def merge_run_operation(batch: RunOperation, operation: RunOperation) -> None: - """ - Merge the `operation` into `batch`, taking into account the special case of `modify_sets`. - - Protobuf merges existing map keys by simply overwriting values, instead of calling - `MergeFrom` on the existing value, eg: A['foo'] = B['foo']. 
- - We want this instead: - - batch = {'sys/tags': 'string': { 'values': {'foo': ADD}}} - operation = {'sys/tags': 'string': { 'values': {'bar': ADD}}} - result = {'sys/tags': 'string': { 'values': {'foo': ADD, 'bar': ADD}}} - - If we called `batch.MergeFrom(operation)` we would get an overwritten value: - result = {'sys/tags': 'string': { 'values': {'bar': ADD}}} - - This function ensures that the `modify_sets` are merged correctly, leaving the default - behaviour for all other fields. - """ - - modify_sets = operation.update.modify_sets - operation.update.ClearField("modify_sets") - - batch.MergeFrom(operation) - - for k, v in modify_sets.items(): - batch.update.modify_sets[k].MergeFrom(v) diff --git a/src/neptune_scale/sync/errors_tracking.py b/src/neptune_scale/sync/errors_tracking.py index 8bdcca75..fd31c043 100644 --- a/src/neptune_scale/sync/errors_tracking.py +++ b/src/neptune_scale/sync/errors_tracking.py @@ -20,13 +20,12 @@ ) from neptune_scale.sync.parameters import ERRORS_MONITOR_THREAD_SLEEP_TIME from neptune_scale.util import get_logger -from neptune_scale.util.abstract import Resource from neptune_scale.util.daemon import Daemon logger = get_logger() -class ErrorsQueue(Resource): +class ErrorsQueue: def __init__(self) -> None: self._errors_queue: multiprocessing.Queue[BaseException] = multiprocessing.Queue() @@ -60,7 +59,7 @@ def default_warning_callback(error: BaseException, last_seen_at: Optional[float] logger.warning(error) -class ErrorsMonitor(Daemon, Resource): +class ErrorsMonitor(Daemon): def __init__( self, errors_queue: ErrorsQueue, diff --git a/src/neptune_scale/sync/lag_tracking.py b/src/neptune_scale/sync/lag_tracking.py index 76c2bb08..475700e9 100644 --- a/src/neptune_scale/sync/lag_tracking.py +++ b/src/neptune_scale/sync/lag_tracking.py @@ -2,27 +2,26 @@ __all__ = ("LagTracker",) +import time from collections.abc import Callable -from time import monotonic from neptune_scale.sync.errors_tracking import ErrorsQueue -from 
neptune_scale.sync.operations_queue import OperationsQueue from neptune_scale.sync.parameters import ( LAG_TRACKER_THREAD_SLEEP_TIME, LAG_TRACKER_TIMEOUT, ) +from neptune_scale.sync.sequence_tracker import SequenceTracker from neptune_scale.util import ( Daemon, SharedFloat, ) -from neptune_scale.util.abstract import Resource -class LagTracker(Daemon, Resource): +class LagTracker(Daemon): def __init__( self, errors_queue: ErrorsQueue, - operations_queue: OperationsQueue, + sequence_tracker: SequenceTracker, last_ack_timestamp: SharedFloat, async_lag_threshold: float, on_async_lag_callback: Callable[[], None], @@ -30,7 +29,7 @@ def __init__( super().__init__(name="LagTracker", sleep_time=LAG_TRACKER_THREAD_SLEEP_TIME) self._errors_queue: ErrorsQueue = errors_queue - self._operations_queue: OperationsQueue = operations_queue + self._sequence_tracker: SequenceTracker = sequence_tracker self._last_ack_timestamp: SharedFloat = last_ack_timestamp self._async_lag_threshold: float = async_lag_threshold self._on_async_lag_callback: Callable[[], None] = on_async_lag_callback @@ -41,7 +40,7 @@ def work(self) -> None: with self._last_ack_timestamp: self._last_ack_timestamp.wait(timeout=LAG_TRACKER_TIMEOUT) last_ack_timestamp = self._last_ack_timestamp.value - last_queued_timestamp = self._operations_queue.last_timestamp + last_queued_timestamp = self._sequence_tracker.last_timestamp # No operations were put into the queue if last_queued_timestamp is None: @@ -49,7 +48,7 @@ def work(self) -> None: # No operations were processed by server if last_ack_timestamp < 0 and not self._callback_triggered: - if monotonic() - last_queued_timestamp > self._async_lag_threshold: + if time.time() - last_queued_timestamp > self._async_lag_threshold: self._callback_triggered = True self._on_async_lag_callback() return diff --git a/src/neptune_scale/sync/metadata_splitter.py b/src/neptune_scale/sync/metadata_splitter.py index 80aec6e4..0726df85 100644 --- 
a/src/neptune_scale/sync/metadata_splitter.py +++ b/src/neptune_scale/sync/metadata_splitter.py @@ -1,5 +1,7 @@ from __future__ import annotations +from neptune_scale.sync.parameters import MAX_SINGLE_OPERATION_SIZE_BYTES + __all__ = ("MetadataSplitter",) import math @@ -46,7 +48,7 @@ T = TypeVar("T", bound=Any) -class MetadataSplitter(Iterator[tuple[RunOperation, int]]): +class MetadataSplitter(Iterator[UpdateRunSnapshot]): def __init__( self, *, @@ -57,7 +59,7 @@ def __init__( metrics: Optional[Metrics], add_tags: Optional[dict[str, Union[list[str], set[str], tuple[str]]]], remove_tags: Optional[dict[str, Union[list[str], set[str], tuple[str]]]], - max_message_bytes_size: int = 1024 * 1024, + max_message_bytes_size: int = MAX_SINGLE_OPERATION_SIZE_BYTES, ): self._should_skip_non_finite_metrics = envs.get_bool(envs.SKIP_NON_FINITE_METRICS, True) self._step = make_step(number=metrics.step) if (metrics is not None and metrics.step is not None) else None @@ -87,7 +89,7 @@ def __iter__(self) -> MetadataSplitter: self._has_returned = False return self - def __next__(self) -> tuple[RunOperation, int]: + def __next__(self) -> UpdateRunSnapshot: update = self._make_empty_update_snapshot() size = update.ByteSize() @@ -116,7 +118,7 @@ def __next__(self) -> tuple[RunOperation, int]: if not self._has_returned or update.assign or update.append or update.modify_sets: self._has_returned = True - return RunOperation(project=self._project, run_id=self._run_id, update=update), size + return update else: raise StopIteration diff --git a/src/neptune_scale/sync/operations_queue.py b/src/neptune_scale/sync/operations_queue.py deleted file mode 100644 index 4dbb1272..00000000 --- a/src/neptune_scale/sync/operations_queue.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -__all__ = ("OperationsQueue",) - -import math -import os -import queue -from collections.abc import Hashable -from multiprocessing import Queue -from time import monotonic -from typing import ( - 
TYPE_CHECKING, - Optional, -) - -from neptune_scale.api.validation import verify_type -from neptune_scale.exceptions import NeptuneUnableToLogData -from neptune_scale.sync.parameters import ( - MAX_MULTIPROCESSING_QUEUE_SIZE, - MAX_QUEUE_ELEMENT_SIZE, - MAX_QUEUE_SIZE, -) -from neptune_scale.sync.queue_element import SingleOperation -from neptune_scale.util import ( - envs, - get_logger, -) -from neptune_scale.util.abstract import Resource - -if TYPE_CHECKING: - from threading import RLock - - from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation - -logger = get_logger() - -# We use this value when comparing time since the last successful put. This is needed to -# avoid cases where the actual wait time is lower than the `timeout` due to the resolution of the monotonic clock. -# The call to put(block=True, timeout=...) could end fractions of a second earlier than the requested timeout. This -# happens eg. on Windows. -# The next call would then block again unnecessarily, if made quickly enough, even though the previous one failed: -# t=0.000s: successful put to empty queue, last_put_time -> 0.000s -# t=0.010s: failed put with timeout=2s, which blocks and ends early after 1.980s -# t=1.995s: monotonic() - last_put_time == 1.995s, which is still less than 2s -> we would block again, -# but we shouldn't -MONOTONIC_CLOCK_RESOLUTION_UPPER_BOUND = 0.1 - - -class OperationsQueue(Resource): - def __init__( - self, - *, - lock: RLock, - max_size: int = MAX_QUEUE_SIZE, - ) -> None: - verify_type("max_size", max_size, int) - - self._lock: RLock = lock - self._max_size: int = max_size - - self._sequence_id: int = 0 - self._last_timestamp: Optional[float] = None - self._queue: Queue[SingleOperation] = Queue(maxsize=min(MAX_MULTIPROCESSING_QUEUE_SIZE, max_size)) - self._last_successful_put_time = monotonic() - - self._max_blocking_time = envs.get_int(envs.LOG_MAX_BLOCKING_TIME_SECONDS, None) or math.inf - if self._max_blocking_time < 0: - raise 
ValueError(f"{envs.LOG_MAX_BLOCKING_TIME_SECONDS} must be a non-negative number.") - - action = os.getenv(envs.LOG_FAILURE_ACTION, "drop") - if action not in ("drop", "raise"): - raise ValueError(f"Invalid value '{action}' for {envs.LOG_FAILURE_ACTION}. Must be 'drop' or 'raise'.") - - self._raise_on_enqueue_failure = action == "raise" - - @property - def queue(self) -> Queue[SingleOperation]: - return self._queue - - @property - def last_sequence_id(self) -> int: - with self._lock: - return self._sequence_id - 1 - - @property - def last_timestamp(self) -> Optional[float]: - with self._lock: - return self._last_timestamp - - def enqueue(self, *, operation: RunOperation, size: Optional[int] = None, key: Hashable = None) -> None: - try: - is_metadata_update = operation.HasField("update") - serialized_operation = operation.SerializeToString() - - if len(serialized_operation) > MAX_QUEUE_ELEMENT_SIZE: - raise ValueError(f"Operation size exceeds the maximum allowed size ({MAX_QUEUE_ELEMENT_SIZE})") - - with self._lock: - self._last_timestamp = monotonic() - - item = SingleOperation( - sequence_id=self._sequence_id, - timestamp=self._last_timestamp, - operation=serialized_operation, - metadata_size=size, - is_batchable=is_metadata_update, - batch_key=key, - ) - - # Optimistically put the item without blocking. If the queue is full, we will retry - # the put with blocking, but only if the last successful put was lest than the `timeout` ago. - # This way if the sync process is stuck, we will drop operations until we are able to successfully - # put an item into the queue again after some of the pending items were processed. 
- try: - self._queue.put_nowait(item) - self._last_successful_put_time = monotonic() - except queue.Full: - elapsed_since_last_put_success = ( - monotonic() - self._last_successful_put_time + MONOTONIC_CLOCK_RESOLUTION_UPPER_BOUND - ) - if elapsed_since_last_put_success < self._max_blocking_time: - try: - self._queue.put(item, block=True, timeout=self._max_blocking_time) - self._last_successful_put_time = monotonic() - except queue.Full: - self._on_enqueue_failed("Operations queue is full", operation) - return - else: - self._on_enqueue_failed("Operations queue is full", operation) - return - - self._sequence_id += 1 - # Pass this through, as it's raised intentionally in _on_enqueue_failed() - except NeptuneUnableToLogData as e: - raise e - except Exception as e: - self._on_enqueue_failed(reason=str(e), operation=operation) - - def close(self) -> None: - self._queue.close() - # This is needed to avoid hanging the main process - self._queue.cancel_join_thread() - - def _on_enqueue_failed(self, reason: str, operation: RunOperation) -> None: - if self._raise_on_enqueue_failure: - raise NeptuneUnableToLogData(reason=reason, operation=str(operation)) - else: - logger.error(f"Dropping operation due to error: {reason}. 
Operation: {operation}") diff --git a/src/neptune_scale/sync/operations_repository.py b/src/neptune_scale/sync/operations_repository.py index 8fc368c3..3b9ba72b 100644 --- a/src/neptune_scale/sync/operations_repository.py +++ b/src/neptune_scale/sync/operations_repository.py @@ -2,9 +2,12 @@ from pathlib import Path +from neptune_scale.sync.parameters import MAX_SINGLE_OPERATION_SIZE_BYTES + __all__ = ("OperationsRepository", "OperationType", "Operation", "Metadata", "SequenceId") import contextlib +import datetime import os import sqlite3 import threading @@ -43,6 +46,10 @@ class Operation: operation: Union[UpdateRunSnapshot, CreateRun] operation_size_bytes: int + @property + def ts(self) -> datetime.datetime: + return datetime.datetime.fromtimestamp(self.timestamp / 1000) + @dataclass(frozen=True) class Metadata: @@ -111,6 +118,10 @@ def save_update_run_snapshots(self, ops: list[UpdateRunSnapshot]) -> SequenceId: for update in ops: serialized_operation = update.SerializeToString() operation_size_bytes = len(serialized_operation) + + if operation_size_bytes > MAX_SINGLE_OPERATION_SIZE_BYTES: + raise RuntimeError(f"Operation size is too large: {operation_size_bytes} bytes") + params.append((current_time, OperationType.UPDATE_SNAPSHOT, serialized_operation, operation_size_bytes)) with self._get_connection() as conn: # type: ignore @@ -140,6 +151,11 @@ def save_create_run(self, run: CreateRun) -> SequenceId: return SequenceId(cursor.lastrowid) # type: ignore def get_operations(self, up_to_bytes: int) -> list[Operation]: + if up_to_bytes < MAX_SINGLE_OPERATION_SIZE_BYTES: + raise RuntimeError( + f"up to bytes is too small: {up_to_bytes} bytes, minimum is {MAX_SINGLE_OPERATION_SIZE_BYTES} bytes" + ) + with self._get_connection() as conn: # type: ignore cursor = conn.cursor() @@ -214,7 +230,7 @@ def save_metadata( count = cursor.fetchone()[0] if count > 0: - raise ValueError("Metadata already exists") + raise RuntimeError("Metadata already exists") # Insert new metadata 
cursor.execute( diff --git a/src/neptune_scale/sync/parameters.py b/src/neptune_scale/sync/parameters.py index f3d4b7fe..b318ca68 100644 --- a/src/neptune_scale/sync/parameters.py +++ b/src/neptune_scale/sync/parameters.py @@ -2,18 +2,9 @@ MAX_RUN_ID_LENGTH = 128 MAX_EXPERIMENT_NAME_LENGTH = 730 -# Operations queue -MAX_BATCH_SIZE = 100000 -MAX_QUEUE_SIZE = 1000000 -MAX_MULTIPROCESSING_QUEUE_SIZE = 32767 -MAX_QUEUE_ELEMENT_SIZE = 1024 * 1024 # 1MB -# Wait up to this many seconds for incoming operations before submitting a batch -BATCH_WAIT_TIME_SECONDS = 1 - # Threads -SYNC_THREAD_SLEEP_TIME = 2 +SYNC_THREAD_SLEEP_TIME = 0.5 STATUS_TRACKING_THREAD_SLEEP_TIME = 1 -INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME = 0.01 ERRORS_MONITOR_THREAD_SLEEP_TIME = 0.1 SYNC_PROCESS_SLEEP_TIME = 1 LAG_TRACKER_THREAD_SLEEP_TIME = 1 @@ -31,3 +22,7 @@ # Status tracking MAX_REQUESTS_STATUS_BATCH_SIZE = 1000 + +# Operations +MAX_SINGLE_OPERATION_SIZE_BYTES = 2 * 1024 * 1024 # 2MB +MAX_REQUEST_SIZE_BYTES = 16 * 1024 * 1024 # 16MB diff --git a/src/neptune_scale/sync/queue_element.py b/src/neptune_scale/sync/queue_element.py deleted file mode 100644 index 18e30de1..00000000 --- a/src/neptune_scale/sync/queue_element.py +++ /dev/null @@ -1,31 +0,0 @@ -__all__ = ("BatchedOperations", "SingleOperation") - -from collections.abc import Hashable -from typing import ( - NamedTuple, - Optional, -) - - -class BatchedOperations(NamedTuple): - # Operation identifier of the last operation in the batch - sequence_id: int - # Timestamp of the last operation in the batch - timestamp: float - # Protobuf serialized (RunOperationBatch) - operation: bytes - - -class SingleOperation(NamedTuple): - # Operation identifier - sequence_id: int - # Timestamp of the operation being enqueued - timestamp: float - # Protobuf serialized (RunOperation) - operation: bytes - # Whether the operation is batchable. Eg. metadata updates are, while run creations are not. 
- is_batchable: bool - # Size of the metadata in the operation (without project, family, run_id etc.) - metadata_size: Optional[int] - # Update metadata key - batch_key: Hashable diff --git a/src/neptune_scale/sync/sequence_tracker.py b/src/neptune_scale/sync/sequence_tracker.py new file mode 100644 index 00000000..a6fea748 --- /dev/null +++ b/src/neptune_scale/sync/sequence_tracker.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +__all__ = ("SequenceTracker",) + +import threading +import time +from typing import Optional + + +class SequenceTracker: + def __init__(self) -> None: + self._lock = threading.RLock() + self._last_sequence_id = -1 + self._last_timestamp: Optional[float] = None + + @property + def last_sequence_id(self) -> int: + with self._lock: + return self._last_sequence_id + + @property + def last_timestamp(self) -> Optional[float]: + with self._lock: + return self._last_timestamp + + def update_sequence_id(self, sequence_id: int) -> None: + with self._lock: + # Use max to ensure that the sequence ID is always increasing + self._last_sequence_id = max(self._last_sequence_id, sequence_id) + self._last_timestamp = time.time() diff --git a/src/neptune_scale/sync/sync_process.py b/src/neptune_scale/sync/sync_process.py index 2f7e645a..56b5d0d4 100644 --- a/src/neptune_scale/sync/sync_process.py +++ b/src/neptune_scale/sync/sync_process.py @@ -1,15 +1,24 @@ from __future__ import annotations +import logging +from pathlib import Path + +from neptune_scale.sync.operations_repository import ( + Metadata, + Operation, + OperationsRepository, + OperationType, + SequenceId, +) + __all__ = ("SyncProcess",) +import datetime import multiprocessing import queue import signal import threading -from multiprocessing import ( - Process, - Queue, -) +from multiprocessing import Process from types import FrameType from typing import ( Generic, @@ -21,6 +30,7 @@ import backoff from neptune_api.proto.google_rpc.code_pb2 import Code +from 
neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import UpdateRunSnapshots from neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 import IngestCode from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import ( BulkRequestStatus, @@ -38,7 +48,6 @@ NeptuneAttributeTypeUnsupported, NeptuneConnectionLostError, NeptuneInternalServerError, - NeptuneOperationsQueueMaxSizeExceeded, NeptunePreviewStepNotAfterLastCommittedStep, NeptuneProjectInvalidName, NeptuneProjectNotFound, @@ -65,22 +74,16 @@ backend_factory, with_api_errors_handling, ) -from neptune_scale.sync.aggregating_queue import AggregatingQueue from neptune_scale.sync.errors_tracking import ErrorsQueue from neptune_scale.sync.parameters import ( - INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME, - MAX_QUEUE_SIZE, MAX_REQUEST_RETRY_SECONDS, + MAX_REQUEST_SIZE_BYTES, MAX_REQUESTS_STATUS_BATCH_SIZE, SHUTDOWN_TIMEOUT, STATUS_TRACKING_THREAD_SLEEP_TIME, SYNC_PROCESS_SLEEP_TIME, SYNC_THREAD_SLEEP_TIME, ) -from neptune_scale.sync.queue_element import ( - BatchedOperations, - SingleOperation, -) from neptune_scale.sync.util import safe_signal_name from neptune_scale.util import ( Daemon, @@ -89,10 +92,6 @@ SharedInt, get_logger, ) -from neptune_scale.util.abstract import ( - Resource, - WithResources, -) T = TypeVar("T") @@ -124,8 +123,8 @@ class StatusTrackingElement(NamedTuple): - sequence_id: int - timestamp: float + sequence_id: SequenceId + timestamp: datetime.datetime request_id: str @@ -167,7 +166,7 @@ def commit(self, n: int) -> None: class SyncProcess(Process): def __init__( self, - operations_queue: Queue, + operations_repository_path: Path, errors_queue: ErrorsQueue, process_link: ProcessLink, api_token: str, @@ -177,11 +176,10 @@ def __init__( last_queued_seq: SharedInt, last_ack_seq: SharedInt, last_ack_timestamp: SharedFloat, - max_queue_size: int = MAX_QUEUE_SIZE, ) -> None: super().__init__(name="SyncProcess") - self._external_operations_queue: Queue[SingleOperation] = operations_queue + 
self._operations_repository_path: Path = operations_repository_path self._errors_queue: ErrorsQueue = errors_queue self._process_link: ProcessLink = process_link self._api_token: str = api_token @@ -190,7 +188,6 @@ def __init__( self._last_queued_seq: SharedInt = last_queued_seq self._last_ack_seq: SharedInt = last_ack_seq self._last_ack_timestamp: SharedFloat = last_ack_timestamp - self._max_queue_size: int = max_queue_size self._mode: Literal["async", "disabled"] = mode # This flag is set when a termination signal is caught @@ -210,157 +207,56 @@ def run(self) -> None: self._process_link.start(on_link_closed=self._on_parent_link_closed) signal.signal(signal.SIGTERM, self._handle_signal) - worker = SyncProcessWorker( - project=self._project, - family=self._family, - api_token=self._api_token, - errors_queue=self._errors_queue, - external_operations_queue=self._external_operations_queue, - last_queued_seq=self._last_queued_seq, - last_ack_seq=self._last_ack_seq, - max_queue_size=self._max_queue_size, - last_ack_timestamp=self._last_ack_timestamp, - mode=self._mode, - ) - worker.start() + status_tracking_queue: PeekableQueue[StatusTrackingElement] = PeekableQueue() + operations_repository = OperationsRepository(db_path=self._operations_repository_path) + threads = [ + SenderThread( + api_token=self._api_token, + operations_repository=operations_repository, + status_tracking_queue=status_tracking_queue, + errors_queue=self._errors_queue, + family=self._family, + last_queued_seq=self._last_queued_seq, + mode=self._mode, + ), + StatusTrackingThread( + api_token=self._api_token, + mode=self._mode, + project=self._project, + errors_queue=self._errors_queue, + status_tracking_queue=status_tracking_queue, + last_ack_seq=self._last_ack_seq, + last_ack_timestamp=self._last_ack_timestamp, + ), + ] + for thread in threads: + thread.start() + try: while not self._stop_event.is_set(): - worker.join(timeout=SYNC_PROCESS_SLEEP_TIME) + for thread in threads: + 
thread.join(timeout=SYNC_PROCESS_SLEEP_TIME) except KeyboardInterrupt: logger.debug("Data synchronization interrupted by user") finally: logger.info("Data synchronization stopping") - worker.interrupt() - worker.wake_up() - worker.join(timeout=SHUTDOWN_TIMEOUT) - worker.close() + for thread in threads: + thread.interrupt() + thread.wake_up() + + for thread in threads: + thread.join(timeout=SHUTDOWN_TIMEOUT) + thread.close() # type: ignore + operations_repository.close() logger.info("Data synchronization finished") -class SyncProcessWorker(WithResources): - def __init__( - self, - *, - api_token: str, - project: str, - family: str, - mode: Literal["async", "disabled"], - errors_queue: ErrorsQueue, - external_operations_queue: multiprocessing.Queue[SingleOperation], - last_queued_seq: SharedInt, - last_ack_seq: SharedInt, - last_ack_timestamp: SharedFloat, - max_queue_size: int = MAX_QUEUE_SIZE, - ) -> None: - self._errors_queue = errors_queue - - self._internal_operations_queue: AggregatingQueue = AggregatingQueue(max_queue_size=max_queue_size) - self._status_tracking_queue: PeekableQueue[StatusTrackingElement] = PeekableQueue() - self._sync_thread = SenderThread( - api_token=api_token, - operations_queue=self._internal_operations_queue, - status_tracking_queue=self._status_tracking_queue, - errors_queue=self._errors_queue, - family=family, - last_queued_seq=last_queued_seq, - mode=mode, - ) - self._external_to_internal_thread = InternalQueueFeederThread( - external=external_operations_queue, - internal=self._internal_operations_queue, - errors_queue=self._errors_queue, - ) - self._status_tracking_thread = StatusTrackingThread( - api_token=api_token, - mode=mode, - project=project, - errors_queue=self._errors_queue, - status_tracking_queue=self._status_tracking_queue, - last_ack_seq=last_ack_seq, - last_ack_timestamp=last_ack_timestamp, - ) - - @property - def threads(self) -> tuple[Daemon, ...]: - return self._external_to_internal_thread, self._sync_thread, 
self._status_tracking_thread - - @property - def resources(self) -> tuple[Resource, ...]: - return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread - - def interrupt(self) -> None: - for thread in self.threads: - thread.interrupt() - - def wake_up(self) -> None: - for thread in self.threads: - thread.wake_up() - - def start(self) -> None: - for thread in self.threads: - thread.start() - - def join(self, timeout: Optional[int] = None) -> None: - # The same timeout will be applied to each thread separately - for thread in self.threads: - thread.join(timeout=timeout) - - -class InternalQueueFeederThread(Daemon, Resource): - def __init__( - self, - external: multiprocessing.Queue[SingleOperation], - internal: AggregatingQueue, - errors_queue: ErrorsQueue, - ) -> None: - super().__init__(name="InternalQueueFeederThread", sleep_time=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME) - - self._external: multiprocessing.Queue[SingleOperation] = external - self._internal: AggregatingQueue = internal - self._errors_queue: ErrorsQueue = errors_queue - - self._latest_unprocessed: Optional[SingleOperation] = None - - def get_next(self) -> Optional[SingleOperation]: - if self._latest_unprocessed is not None: - return self._latest_unprocessed - - try: - self._latest_unprocessed = self._external.get(timeout=INTERNAL_QUEUE_FEEDER_THREAD_SLEEP_TIME) - return self._latest_unprocessed - except queue.Empty: - return None - - def commit(self) -> None: - self._latest_unprocessed = None - - def work(self) -> None: - try: - while not self._is_interrupted(): - operation = self.get_next() - if operation is None: - continue - - try: - self._internal.put_nowait(operation) - self.commit() - except queue.Full: - logger.debug("Internal queue is full (%d elements), waiting for free space", self._internal.maxsize) - self._errors_queue.put(NeptuneOperationsQueueMaxSizeExceeded(max_size=self._internal.maxsize)) - # Sleep before retry - break - except Exception as e: - 
self._errors_queue.put(e) - self.interrupt() - raise NeptuneSynchronizationStopped() from e - - -class SenderThread(Daemon, WithResources): +class SenderThread(Daemon): def __init__( self, api_token: str, family: str, - operations_queue: AggregatingQueue, + operations_repository: OperationsRepository, status_tracking_queue: PeekableQueue[StatusTrackingElement], errors_queue: ErrorsQueue, last_queued_seq: SharedInt, @@ -370,33 +266,14 @@ def __init__( self._api_token: str = api_token self._family: str = family - self._operations_queue: AggregatingQueue = operations_queue + self._operations_repository: OperationsRepository = operations_repository self._status_tracking_queue: PeekableQueue[StatusTrackingElement] = status_tracking_queue self._errors_queue: ErrorsQueue = errors_queue self._last_queued_seq: SharedInt = last_queued_seq self._mode: Literal["async", "disabled"] = mode self._backend: Optional[ApiClient] = None - self._latest_unprocessed: Optional[BatchedOperations] = None - - def get_next(self) -> Optional[BatchedOperations]: - if self._latest_unprocessed is not None: - return self._latest_unprocessed - - try: - self._latest_unprocessed = self._operations_queue.get() - return self._latest_unprocessed - except queue.Empty: - return None - - def commit(self) -> None: - self._latest_unprocessed = None - - @property - def resources(self) -> tuple[Resource, ...]: - if self._backend is not None: - return (self._backend,) - return () + self._metadata: Metadata = operations_repository.get_metadata() # type: ignore @backoff.on_exception(backoff.expo, NeptuneRetryableError, max_time=MAX_REQUEST_RETRY_SECONDS) @with_api_errors_handling @@ -414,34 +291,46 @@ def submit(self, *, operation: RunOperation) -> Optional[SubmitResponse]: def work(self) -> None: try: - while (operation := self.get_next()) is not None: - sequence_id, timestamp, data = operation - - try: - logger.debug("Submitting operation #%d with size of %d bytes", sequence_id, len(data)) - run_operation = 
RunOperation() - run_operation.ParseFromString(data) - request_ids: Optional[SubmitResponse] = self.submit(operation=run_operation) - - if request_ids is None or not request_ids.request_ids: - raise NeptuneUnexpectedError("Server response is empty") + max_operations_size = ( + MAX_REQUEST_SIZE_BYTES + - len(self._metadata.run_id) + - len(self._metadata.project) + - 200 # 200 bytes for RunOperation overhead, + ) + while operations := self._operations_repository.get_operations(up_to_bytes=max_operations_size): + partitioned_operations = _partition_by_type_and_size( + operations, self._metadata.run_id, self._metadata.project, max_operations_size + ) + for run_operation, sequence_id, timestamp in partitioned_operations: + try: + if logger.isEnabledFor(logging.DEBUG): + logger.debug( + "Submitting operation #%d with size of %d bytes", sequence_id, run_operation.ByteSize() + ) + request_ids: Optional[SubmitResponse] = self.submit(operation=run_operation) + + if request_ids is None or not request_ids.request_ids: + raise NeptuneUnexpectedError("Server response is empty") + + last_request_id = request_ids.request_ids[-1] + + logger.debug("Operation #%d submitted as %s", sequence_id, last_request_id) + self._status_tracking_queue.put( + StatusTrackingElement( + sequence_id=sequence_id, request_id=last_request_id, timestamp=timestamp + ) + ) + + self._operations_repository.delete_operations(up_to_seq_id=sequence_id) + # Update Last PUT sequence id and notify threads in the main process + with self._last_queued_seq: + self._last_queued_seq.value = sequence_id + self._last_queued_seq.notify_all() + except NeptuneRetryableError as e: + self._errors_queue.put(e) + # Sleep before retry + return - last_request_id = request_ids.request_ids[-1] - - logger.debug("Operation #%d submitted as %s", sequence_id, last_request_id) - self._status_tracking_queue.put( - StatusTrackingElement(sequence_id=sequence_id, request_id=last_request_id, timestamp=timestamp) - ) - self.commit() - except 
NeptuneRetryableError as e: - self._errors_queue.put(e) - # Sleep before retry - break - - # Update Last PUT sequence id and notify threads in the main process - with self._last_queued_seq: - self._last_queued_seq.value = sequence_id - self._last_queued_seq.notify_all() except Exception as e: self._errors_queue.put(e) with self._last_queued_seq: @@ -449,6 +338,10 @@ def work(self) -> None: self.interrupt() raise NeptuneSynchronizationStopped() from e + def close(self) -> None: + if self._backend is not None: + self._backend.close() + def _raise_exception(status_code: int) -> None: logger.error("HTTP response error: %s", status_code) @@ -464,7 +357,54 @@ def _raise_exception(status_code: int) -> None: raise NeptuneUnexpectedResponseError() -class StatusTrackingThread(Daemon, WithResources): +def _partition_by_type_and_size( + operations: list[Operation], run_id: str, project: str, max_batch_size: int +) -> list[tuple[RunOperation, SequenceId, datetime.datetime]]: + grouped: list[list[Operation]] = [] + batch: list[Operation] = [] + batch_type: Optional[OperationType] = None + batch_size = 0 + + def op_size_with_overhead(_op: Operation) -> int: + return _op.operation_size_bytes + 5 # 1 byte for wire type, 4 bytes for size (2mb) + + for op in operations: + reset_batch = ( + # we don't mix operation types in a single batch + op.operation_type != batch_type + # only one CREATE_RUN per batch + or batch_type == OperationType.CREATE_RUN + # batch cannot be too big + or batch_size + op_size_with_overhead(op) > max_batch_size + ) + if reset_batch: + if batch: + grouped.append(batch) + batch = [] + batch_type = op.operation_type + batch_size = 0 + + batch.append(op) + batch_size += op_size_with_overhead(op) + + if batch: + grouped.append(batch) + + def to_run_operation(ops: list[Operation]) -> tuple[RunOperation, SequenceId, datetime.datetime]: + if ops[0].operation_type == OperationType.CREATE_RUN: + return ( + RunOperation(project=project, run_id=run_id, 
create=ops[0].operation), # type: ignore + ops[-1].sequence_id, + ops[-1].ts, + ) + else: + snapshots = UpdateRunSnapshots(snapshots=[_op.operation for _op in ops]) # type: ignore + return RunOperation(project=project, run_id=run_id, update_batch=snapshots), ops[-1].sequence_id, ops[-1].ts + + return [(to_run_operation(ops)) for ops in grouped] # type: ignore + + +class StatusTrackingThread(Daemon): def __init__( self, api_token: str, @@ -487,11 +427,9 @@ def __init__( self._backend: Optional[ApiClient] = None - @property - def resources(self) -> tuple[Resource, ...]: + def close(self) -> None: if self._backend is not None: - return (self._backend,) - return () + self._backend.close() def get_next(self) -> Optional[list[StatusTrackingElement]]: try: @@ -559,7 +497,7 @@ def work(self) -> None: # Update Last ACK timestamp and notify threads in the main process if processed_timestamp is not None: with self._last_ack_timestamp: - self._last_ack_timestamp.value = processed_timestamp + self._last_ack_timestamp.value = processed_timestamp.timestamp() self._last_ack_timestamp.notify_all() else: # Sleep before retry diff --git a/src/neptune_scale/util/abstract.py b/src/neptune_scale/util/abstract.py deleted file mode 100644 index 80538c36..00000000 --- a/src/neptune_scale/util/abstract.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -from abc import ( - ABC, - abstractmethod, -) -from types import TracebackType -from typing import Optional - - -class AutoCloseable(ABC): - def __enter__(self) -> AutoCloseable: - return self - - @abstractmethod - def close(self) -> None: ... 
- - def __exit__( - self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> None: - self.close() - - -class Resource(AutoCloseable): - def cleanup(self) -> None: - pass - - def flush(self) -> None: - pass - - def close(self) -> None: - self.flush() - - -class WithResources(Resource): - @property - @abstractmethod - def resources(self) -> tuple[Resource, ...]: ... - - def flush(self) -> None: - for resource in self.resources: - resource.flush() - - def close(self) -> None: - for resource in self.resources: - resource.close() - - def cleanup(self) -> None: - for resource in self.resources: - resource.cleanup() diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index db600aef..7820eec6 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -46,17 +46,6 @@ def project(request): return ReadOnlyProject(project=project_name) -class SyncRun(Run): - """A neptune_scale.Run instance that waits for processing to complete - after each logging method call. 
This is useful for e2e tests, where we - usually want to wait for the data to be available before fetching it.""" - - def _log(self, *args, **kwargs): - result = super()._log(*args, **kwargs) - self.wait_for_processing() - return result - - @fixture(scope="module") def run_init_kwargs(project): """Arguments to initialize a neptune_scale.Run instance""" @@ -85,13 +74,9 @@ def run(project, run_init_kwargs): run = Run(**run_init_kwargs) run.log_configs({"test_start_time": datetime.now(timezone.utc)}) - return run + yield run - -@fixture(scope="module") -def sync_run(project, run, run_init_kwargs): - """Blocking run for logging data""" - return SyncRun(project=run_init_kwargs["project"], run_id=run_init_kwargs["run_id"], resume=True) + run.terminate() @fixture diff --git a/tests/e2e/test_log_and_fetch.py b/tests/e2e/test_log_and_fetch.py index fc132bf0..51427894 100644 --- a/tests/e2e/test_log_and_fetch.py +++ b/tests/e2e/test_log_and_fetch.py @@ -1,6 +1,7 @@ import math import os import random +import threading import time import uuid from datetime import ( @@ -12,7 +13,10 @@ from neptune_fetcher import ReadOnlyRun from pytest import mark +from neptune_scale.api.run import Run + NEPTUNE_PROJECT = os.getenv("NEPTUNE_E2E_PROJECT") +SYNC_TIMEOUT = 30 def unique_path(prefix): @@ -38,7 +42,7 @@ def random_series(length=10, start_step=0): return steps, values -def test_atoms(sync_run, ro_run): +def test_atoms(run, ro_run): """Set atoms to a value, make sure it's equal when fetched""" now = time.time() @@ -52,7 +56,8 @@ def test_atoms(sync_run, ro_run): "datetime-value": datetime.now(timezone.utc).replace(microsecond=0), } - sync_run.log_configs(data) + run.log_configs(data) + run.wait_for_processing(SYNC_TIMEOUT) for key, value in data.items(): assert ro_run[key].fetch() == value, f"Value for {key} does not match" @@ -67,7 +72,8 @@ def test_atoms(sync_run, ro_run): "datetime-value": datetime.now(timezone.utc).replace(year=1999, microsecond=0), } - 
sync_run.log_configs(updated_data) + run.log_configs(updated_data) + run.wait_for_processing(SYNC_TIMEOUT) # The data should stay the same, as we haven't purged the cache yet for key, value in data.items(): @@ -89,7 +95,7 @@ def test_series_no_prefetch(run, ro_run): for step, value in zip(steps, values): run.log_metrics(data={path: value}, step=step) - run.wait_for_processing() + run.wait_for_processing(SYNC_TIMEOUT) df = ro_run[path].fetch_values() assert df["step"].tolist() == steps @@ -104,7 +110,7 @@ def test_single_series_with_prefetch(run, ro_run): for step, value in zip(steps, values): run.log_metrics(data={path: value}, step=step) - run.wait_for_processing() + run.wait_for_processing(SYNC_TIMEOUT) ro_run.prefetch_series_values([path], use_threads=True) df = ro_run[path].fetch_values() @@ -118,7 +124,7 @@ def test_multiple_series_with_prefetch(run, ro_run): data = {f"{path_base}-{i}": i for i in range(20)} run.log_metrics(data, step=1) - run.wait_for_processing() + run.wait_for_processing(SYNC_TIMEOUT) ro_run = refresh(ro_run) paths = [p for p in ro_run.field_names if p.startswith(path_base)] @@ -141,7 +147,7 @@ def test_series_fetch_and_append(run, ro_run): for step, value in zip(steps, values): run.log_metrics(data={path: value}, step=step) - run.wait_for_processing() + run.wait_for_processing(SYNC_TIMEOUT) df = ro_run[path].fetch_values() assert df["step"].tolist() == steps @@ -152,7 +158,7 @@ def test_series_fetch_and_append(run, ro_run): for step, value in zip(steps2, values2): run.log_metrics(data={path: value}, step=step) - run.wait_for_processing() + run.wait_for_processing(SYNC_TIMEOUT) df = ro_run[path].fetch_values() assert df["step"].tolist() == steps + steps2 @@ -160,7 +166,36 @@ def test_series_fetch_and_append(run, ro_run): @mark.parametrize("value", [np.inf, -np.inf, np.nan, math.inf, -math.inf, math.nan]) -def test_single_non_finite_metric(value, sync_run, ro_run): +def test_single_non_finite_metric(value, run, ro_run): path = 
unique_path("test_series/non_finite") - sync_run.log_metrics(data={path: value}, step=1) + + run.log_metrics(data={path: value}, step=1) + run.wait_for_processing(SYNC_TIMEOUT) assert path not in refresh(ro_run).field_names + + +def test_async_lag_callback(): + event = threading.Event() + with Run( + project=NEPTUNE_PROJECT, + run_id=f"{uuid.uuid4()}", + async_lag_threshold=0.000001, + on_async_lag_callback=lambda: event.set(), + ) as run: + run.wait_for_processing(SYNC_TIMEOUT) + + # First callback should be called after run creation + event.wait(timeout=60) + assert event.is_set() + event.clear() + + run.log_configs( + data={ + "parameters/learning_rate": 0.001, + "parameters/batch_size": 64, + }, + ) + + # Second callback should be called after logging configs + event.wait(timeout=60) + assert event.is_set() diff --git a/tests/unit/sync/test_operations_repository.py b/tests/unit/sync/test_operations_repository.py index 132126a6..d68c04fa 100644 --- a/tests/unit/sync/test_operations_repository.py +++ b/tests/unit/sync/test_operations_repository.py @@ -16,6 +16,7 @@ OperationsRepository, OperationType, ) +from neptune_scale.sync.parameters import MAX_SINGLE_OPERATION_SIZE_BYTES @pytest.fixture @@ -101,7 +102,7 @@ def test_save_create_run(operations_repo, temp_db_path): count = get_operation_count(temp_db_path) assert count == 1 - operation = operations_repo.get_operations(up_to_bytes=10000)[0] + operation = operations_repo.get_operations(up_to_bytes=MAX_SINGLE_OPERATION_SIZE_BYTES)[0] assert operation.operation_type == OperationType.CREATE_RUN assert operation.sequence_id == 1 @@ -112,8 +113,7 @@ def test_get_operations(operations_repo): # Given snapshots = [] for i in range(5): - # Create snapshots with increasing sizes - up to 5MB - snapshot = UpdateRunSnapshot(assign={f"key_{i}": Value(string="a" * (1024 * 1024 * (i + 1)))}) + snapshot = UpdateRunSnapshot(assign={f"key_{i}": Value(string="a" * (1024 * 1024 * 2 - 100))}) snapshots.append(snapshot) 
operations_repo.save_update_run_snapshots(snapshots) @@ -139,7 +139,7 @@ def test_get_operations_size_based_pagination_with_many_items(operations_repo): operations_count = 150_000 snapshots = [] for i in range(operations_count): - snapshot = UpdateRunSnapshot(assign={f"key_{i}": Value(string=f"{i}")}) + snapshot = UpdateRunSnapshot(assign={f"key_{i}": Value(string=f"{i}" * 50)}) snapshots.append(snapshot) operations_repo.save_update_run_snapshots(snapshots) @@ -156,7 +156,7 @@ def test_get_operations_size_based_pagination_with_many_items(operations_repo): def test_get_operations_empty_db(operations_repo): # Given - operations = operations_repo.get_operations(up_to_bytes=10000) + operations = operations_repo.get_operations(up_to_bytes=MAX_SINGLE_OPERATION_SIZE_BYTES) assert len(operations) == 0 @@ -170,7 +170,7 @@ def test_delete_operations(operations_repo, temp_db_path): operations_repo.save_update_run_snapshots(snapshots) # Get the operations to find their sequence IDs - operations = operations_repo.get_operations(up_to_bytes=10000) + operations = operations_repo.get_operations(up_to_bytes=MAX_SINGLE_OPERATION_SIZE_BYTES) assert len(operations) == 5 # When - delete the first 3 operations @@ -229,7 +229,7 @@ def test_get_metadata_nonexistent(operations_repo): def test_metadata_already_exists_error(operations_repo): operations_repo.save_metadata(project="test", run_id="test") - with pytest.raises(ValueError, match="Metadata already exists"): + with pytest.raises(RuntimeError, match="Metadata already exists"): operations_repo.save_metadata(project="test2", run_id="test2") @@ -266,6 +266,18 @@ def test_timestamp_in_operations(mock_time, operations_repo): assert timestamp == expected_timestamp +def test_get_operations_up_to_bytes_too_small(operations_repo): + with pytest.raises(RuntimeError, match=r"up to bytes is too small: 100 bytes.*"): + operations_repo.get_operations(up_to_bytes=100) + + +def test_save_update_run_snapshots_too_large(operations_repo): + with 
pytest.raises(RuntimeError, match="Operation size is too large: 2097172 bytes"): + operations_repo.save_update_run_snapshots( + [UpdateRunSnapshot(assign={"key": Value(string="a" * 1024 * 1024 * 2)})] + ) + + def get_operation_count(db_path: str) -> int: conn = sqlite3.connect(db_path) try: diff --git a/tests/unit/sync/test_sequence_tracker.py b/tests/unit/sync/test_sequence_tracker.py new file mode 100644 index 00000000..addc8b23 --- /dev/null +++ b/tests/unit/sync/test_sequence_tracker.py @@ -0,0 +1,23 @@ +from unittest.mock import patch + +from neptune_scale.sync.sequence_tracker import SequenceTracker + + +def test_update_sequence_id(): + with patch("time.time", return_value=123.456): + tracker = SequenceTracker() + + # Update with a positive sequence ID + tracker.update_sequence_id(5) + assert tracker.last_sequence_id == 5 + assert tracker.last_timestamp == 123.456 + + # Update with a higher sequence ID + tracker.update_sequence_id(10) + assert tracker.last_sequence_id == 10 + assert tracker.last_timestamp == 123.456 + + # Update with a lower sequence ID + tracker.update_sequence_id(7) + assert tracker.last_sequence_id == 10 # Should not decrease + assert tracker.last_timestamp == 123.456 diff --git a/tests/unit/test_aggregating_queue.py b/tests/unit/test_aggregating_queue.py deleted file mode 100644 index de9394a0..00000000 --- a/tests/unit/test_aggregating_queue.py +++ /dev/null @@ -1,566 +0,0 @@ -import time -from queue import ( - Empty, - Full, -) - -import pytest -from freezegun import freeze_time -from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun -from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( - Step, - UpdateRunSnapshot, - Value, -) -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation - -from neptune_scale.sync.aggregating_queue import AggregatingQueue -from neptune_scale.sync.queue_element import ( - BatchedOperations, - SingleOperation, -) - - -@freeze_time("2024-09-01") -def 
test__simple(): - # given - update = UpdateRunSnapshot(assign={f"key_{i}": Value(string=("a" * 2)) for i in range(2)}) - operation = RunOperation(update=update) - element = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation.SerializeToString(), - is_batchable=True, - metadata_size=update.ByteSize(), - batch_key=None, - ) - - # and - queue = AggregatingQueue(max_queue_size=1) - - # when - queue.put_nowait(element=element) - - # then - assert queue.get() == BatchedOperations( - sequence_id=1, - timestamp=element.timestamp, - operation=element.operation, - ) - - -@freeze_time("2024-09-01") -def test__max_size_exceeded(): - # given - operation1 = RunOperation() - operation2 = RunOperation() - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=True, - metadata_size=0, - batch_key=None, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=0, - batch_key=None, - ) - - # and - queue = AggregatingQueue(max_queue_size=1) - - # when - queue.put_nowait(element=element1) - - # then - assert True - - # when - with pytest.raises(Full): - queue.put_nowait(element=element2) - - -@freeze_time("2024-09-01") -def test__empty(): - # given - queue = AggregatingQueue(max_queue_size=1) - - # when - with pytest.raises(Empty): - _ = queue.get() - - -@freeze_time("2024-09-01") -def test__batch_size_limit(): - # given - update1 = UpdateRunSnapshot(step=None, assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) - update2 = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) - operation1 = RunOperation(update=update1) - operation2 = RunOperation(update=update2) - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=True, - 
metadata_size=update1.ByteSize(), - batch_key=None, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=update2.ByteSize(), - batch_key=None, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_elements_in_batch=1) - - # when - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # then - assert queue.get() == BatchedOperations(sequence_id=1, timestamp=element1.timestamp, operation=element1.operation) - assert queue.get() == BatchedOperations(sequence_id=2, timestamp=element2.timestamp, operation=element2.operation) - - -@freeze_time("2024-09-01") -def test__batching(): - # given - update1 = UpdateRunSnapshot(step=None, assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) - update2 = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) - - # and - operation1 = RunOperation(update=update1, project="project", run_id="run_id") - operation2 = RunOperation(update=update2, project="project", run_id="run_id") - - # and - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=True, - metadata_size=update1.ByteSize(), - batch_key=None, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=update2.ByteSize(), - batch_key=None, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_elements_in_batch=2) - - # and - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # when - result = queue.get() - - # then - assert result.sequence_id == 2 - assert result.timestamp == element2.timestamp - - # and - batch = RunOperation() - batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id" - assert all(k in batch.update.assign 
for k in ["aa0", "aa1", "bb0", "bb1"]) - - -@freeze_time("2024-09-01") -def test__queue_element_size_limit_with_different_steps(): - # given - update1 = UpdateRunSnapshot(step=Step(whole=1), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) - update2 = UpdateRunSnapshot(step=Step(whole=2), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) - operation1 = RunOperation(update=update1) - operation2 = RunOperation(update=update2) - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=True, - metadata_size=update1.ByteSize(), - batch_key=1.0, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=update2.ByteSize(), - batch_key=2.0, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_queue_element_size=update1.ByteSize()) - - # when - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # then - assert queue.get() == BatchedOperations(sequence_id=1, timestamp=element1.timestamp, operation=element1.operation) - assert queue.get() == BatchedOperations(sequence_id=2, timestamp=element2.timestamp, operation=element2.operation) - - -@freeze_time("2024-09-01") -def test__not_merge_two_run_creation(): - # given - create1 = CreateRun(family="family", run_id="run_id1") - create2 = CreateRun(family="family", run_id="run_id2") - - # and - operation1 = RunOperation(create=create1, project="project", run_id="run_id1") - operation2 = RunOperation(create=create2, project="project", run_id="run_id2") - - # and - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=False, - metadata_size=0, - batch_key=None, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=False, - 
metadata_size=0, - batch_key=None, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_elements_in_batch=2) - - # and - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # when - result = queue.get() - - # then - assert result.sequence_id == 1 - assert result.timestamp == element1.timestamp - - # and - batch = RunOperation() - batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id1" - assert batch.create == create1 - - # when - result = queue.get() - - # then - assert result.sequence_id == 2 - assert result.timestamp == element2.timestamp - - # and - batch = RunOperation() - batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id2" - assert batch.create == create2 - - -@freeze_time("2024-09-01") -def test__not_merge_run_creation_with_metadata_update(): - # given - create = CreateRun(family="family", run_id="run_id") - update = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) - - # and - operation1 = RunOperation(create=create, project="project", run_id="run_id") - operation2 = RunOperation(update=update, project="project", run_id="run_id") - - # and - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=False, - metadata_size=0, - batch_key=None, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=update.ByteSize(), - batch_key=None, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_elements_in_batch=2) - - # and - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # when - result = queue.get() - - # then - assert result.sequence_id == 1 - assert result.timestamp == element1.timestamp - - # and - batch = RunOperation() - 
batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id" - assert batch.create == create - - # when - result = queue.get() - - # then - assert result.sequence_id == 2 - assert result.timestamp == element2.timestamp - - # and - batch = RunOperation() - batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id" - assert batch.update == update - - -@freeze_time("2024-09-01") -def test__merge_same_key(): - # given - update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) - update2 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) - - # and - operation1 = RunOperation(update=update1, project="project", run_id="run_id") - operation2 = RunOperation(update=update2, project="project", run_id="run_id") - - # and - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=True, - metadata_size=update1.ByteSize(), - batch_key=1.0, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=update2.ByteSize(), - batch_key=1.0, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_elements_in_batch=2) - - # and - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # when - result = queue.get() - - # then - assert result.sequence_id == 2 - assert result.timestamp == element2.timestamp - - # and - batch = RunOperation() - batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id" - assert batch.update.step == Step(whole=1, micro=0) - assert all(k in batch.update.assign for k in ["aa0", "aa1", "bb0", "bb1"]) - - -@freeze_time("2024-09-01") -def test__merge_two_different_steps(): - # given - 
update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) - update2 = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) - - # and - operation1 = RunOperation(update=update1, project="project", run_id="run_id") - operation2 = RunOperation(update=update2, project="project", run_id="run_id") - - # and - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=True, - metadata_size=0, - batch_key=1.0, - ) - element2 = SingleOperation( - sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=0, - batch_key=2.0, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_elements_in_batch=2) - - # and - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # when - result = queue.get() - - # then - assert result.sequence_id == element2.sequence_id - assert result.timestamp == element2.timestamp - - # and - batch = RunOperation() - batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id" - assert batch.update_batch.snapshots == [update1, update2] - - -@freeze_time("2024-09-01") -def test__merge_step_with_none(): - # given - update1 = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={f"aa{i}": Value(int64=(i * 97)) for i in range(2)}) - update2 = UpdateRunSnapshot(step=None, assign={f"bb{i}": Value(int64=(i * 25)) for i in range(2)}) - - # and - operation1 = RunOperation(update=update1, project="project", run_id="run_id") - operation2 = RunOperation(update=update2, project="project", run_id="run_id") - - # and - element1 = SingleOperation( - sequence_id=1, - timestamp=time.process_time(), - operation=operation1.SerializeToString(), - is_batchable=True, - metadata_size=0, - batch_key=1.0, - ) - element2 = SingleOperation( - 
sequence_id=2, - timestamp=time.process_time(), - operation=operation2.SerializeToString(), - is_batchable=True, - metadata_size=0, - batch_key=None, - ) - - # and - queue = AggregatingQueue(max_queue_size=2, max_elements_in_batch=2) - - # and - queue.put_nowait(element=element1) - queue.put_nowait(element=element2) - - # when - result = queue.get() - - # then - assert result.sequence_id == element2.sequence_id - assert result.timestamp == element2.timestamp - - # and - batch = RunOperation() - batch.ParseFromString(result.operation) - - assert batch.project == "project" - assert batch.run_id == "run_id" - assert batch.update_batch.snapshots == [update2, update1] # None is always first - - -@freeze_time("2024-09-01") -def test__merge_two_steps_two_metrics(): - # given - update1a = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"aa": Value(int64=10)}) - update2a = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={"aa": Value(int64=20)}) - update1b = UpdateRunSnapshot(step=Step(whole=1, micro=0), assign={"bb": Value(int64=100)}) - update2b = UpdateRunSnapshot(step=Step(whole=2, micro=0), assign={"bb": Value(int64=200)}) - - # and - operations = [ - RunOperation(update=update, project="project", run_id="run_id") - for update in [update1a, update2a, update1b, update2b] - ] - - # and - elements = [ - SingleOperation( - sequence_id=sequence_id, - timestamp=time.process_time(), - operation=operation.SerializeToString(), - is_batchable=True, - metadata_size=0, - batch_key=batch_key, - ) - for sequence_id, batch_key, operation in [ - (1, 1.0, operations[0]), - (2, 2.0, operations[1]), - (3, 1.0, operations[2]), - (4, 2.0, operations[3]), - ] - ] - - # and - queue = AggregatingQueue(max_queue_size=4, max_elements_in_batch=4) - - # and - for element in elements: - queue.put_nowait(element=element) - - # when - result = queue.get() - - # then - assert result.sequence_id == elements[-1].sequence_id - assert result.timestamp == elements[-1].timestamp - - # and - 
batch = RunOperation() - batch.ParseFromString(result.operation) - - update1_merged = UpdateRunSnapshot( - step=Step(whole=1, micro=0), assign={"aa": Value(int64=10), "bb": Value(int64=100)} - ) - update2_merged = UpdateRunSnapshot( - step=Step(whole=2, micro=0), assign={"aa": Value(int64=20), "bb": Value(int64=200)} - ) - - assert batch.project == "project" - assert batch.run_id == "run_id" - assert batch.update_batch.snapshots == [update1_merged, update2_merged] diff --git a/tests/unit/test_attribute.py b/tests/unit/test_attribute.py index c875ead2..15c82244 100644 --- a/tests/unit/test_attribute.py +++ b/tests/unit/test_attribute.py @@ -1,3 +1,4 @@ +import uuid from datetime import ( datetime, timedelta, @@ -21,7 +22,7 @@ @fixture def run(api_token): - run = Run(project="dummy/project", run_id="dummy-run", mode="disabled", api_token=api_token) + run = Run(project="dummy/project", run_id=f"{uuid.uuid4()}", mode="disabled", api_token=api_token) run._attr_store.log = Mock() with run: yield run diff --git a/tests/unit/test_errors_monitor.py b/tests/unit/test_errors_monitor.py index fe11b78e..c43fdc34 100644 --- a/tests/unit/test_errors_monitor.py +++ b/tests/unit/test_errors_monitor.py @@ -52,7 +52,6 @@ def callback_with_event(*args, **kwargs) -> None: # when errors_queue.put(error) - errors_queue.flush() errors_monitor.wake_up() # then diff --git a/tests/unit/test_integration_batching.py b/tests/unit/test_integration_batching.py deleted file mode 100644 index d4f6d844..00000000 --- a/tests/unit/test_integration_batching.py +++ /dev/null @@ -1,100 +0,0 @@ -import threading -from datetime import datetime - -import pytest -from freezegun import freeze_time -from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( - Preview, - UpdateRunSnapshot, -) -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation - -from neptune_scale.api.attribute import AttributeStore -from neptune_scale.api.metrics import Metrics -from 
neptune_scale.net.serialization import ( - datetime_to_proto, - make_step, - make_value, -) -from neptune_scale.sync.aggregating_queue import AggregatingQueue -from neptune_scale.sync.operations_queue import OperationsQueue - - -@pytest.mark.parametrize( - "metrics,expected_updates", - [ - pytest.param( - [ - Metrics(data={"x": 1, "y": 5}, step=1), - Metrics(data={"a": 2}, step=1), - ], - [ - {"step": 1, "append": {"a": 2, "x": 1, "y": 5}}, - ], - id="Different metrics, same step", - ), - pytest.param( - [ - Metrics(data={"a": 1, "b": 2}, step=1), - Metrics(data={"a": 2}, step=2), - ], - [ - {"step": 1, "append": {"a": 1, "b": 2}}, - {"step": 2, "append": {"a": 2}}, - ], - id="Different step", - ), - pytest.param( - [ - Metrics(data={"a": 1, "b": 2}, step=1, preview=True, preview_completion=0.2), - Metrics(data={"a": 10, "b": 20}, step=1, preview=True, preview_completion=0.8), - Metrics(data={"a": 100, "b": 200}, step=1), - ], - [ - {"step": 1, "append": {"a": 1, "b": 2}, "preview": True, "preview_completion": 0.2}, - {"step": 1, "append": {"a": 10, "b": 20}, "preview": True, "preview_completion": 0.8}, - {"step": 1, "append": {"a": 100, "b": 200}}, - ], - id="Multiple previews for same point", - ), - ], -) -@freeze_time("2025-02-01") -def test__merge_metrics(metrics, expected_updates): - # given - op_queue = OperationsQueue(lock=threading.RLock(), max_size=1000) - store = AttributeStore("project", "run_id", op_queue) - agg_queue = AggregatingQueue(1000) - - # when - for m in metrics: - store.log(metrics=m) - agg_queue.put_nowait(op_queue.queue.get()) - - result = agg_queue.get() - batch = RunOperation() - batch.ParseFromString(result.operation) - - # then - assert batch.project == "project" - assert batch.run_id == "run_id" - - results = batch.update_batch.snapshots if batch.update_batch.snapshots else [batch.update] - assert len(results) == len(expected_updates) - for expected in expected_updates: - preview = ( - Preview(is_preview=expected["preview"], 
completion_ratio=expected.get("preview_completion", 0.0)) - if "preview" in expected - else None - ) - exp_proto = UpdateRunSnapshot( - step=make_step(expected.get("step")), - timestamp=datetime_to_proto(datetime.now()), - append={k: make_value(float(v)) for k, v in expected.get("append", {}).items()}, - preview=preview, - ) - for got in results: - if got == exp_proto: - break - else: - pytest.fail(f"didn't find expected result: {exp_proto}") diff --git a/tests/unit/test_lag_tracker.py b/tests/unit/test_lag_tracker.py index d724246f..6c710af7 100644 --- a/tests/unit/test_lag_tracker.py +++ b/tests/unit/test_lag_tracker.py @@ -5,6 +5,7 @@ from freezegun import freeze_time from neptune_scale.sync.lag_tracking import LagTracker +from neptune_scale.sync.sequence_tracker import SequenceTracker from neptune_scale.util import SharedFloat @@ -16,7 +17,8 @@ def test__lag_tracker__callback_called(): # and errors_queue = Mock() - operations_queue = Mock(last_timestamp=time.time()) + sequence_tracker = SequenceTracker() + sequence_tracker.update_sequence_id(1) # This will set last_timestamp last_ack_timestamp = SharedFloat(time.time() - lag) callback = Mock() @@ -31,7 +33,7 @@ def callback_with_event() -> None: # and lag_tracker = LagTracker( errors_queue=errors_queue, - operations_queue=operations_queue, + sequence_tracker=sequence_tracker, last_ack_timestamp=last_ack_timestamp, async_lag_threshold=async_lag_threshold, on_async_lag_callback=callback_with_event, @@ -60,14 +62,15 @@ def test__lag_tracker__not_called(): # and errors_queue = Mock() - operations_queue = Mock(last_timestamp=time.time()) + sequence_tracker = SequenceTracker() + sequence_tracker.update_sequence_id(1) # This will set last_timestamp to current time last_ack_timestamp = SharedFloat(time.time() - lag) callback = Mock() # and lag_tracker = LagTracker( errors_queue=errors_queue, - operations_queue=operations_queue, + sequence_tracker=sequence_tracker, last_ack_timestamp=last_ack_timestamp, 
async_lag_threshold=async_lag_threshold, on_async_lag_callback=callback, diff --git a/tests/unit/test_metadata_splitter.py b/tests/unit/test_metadata_splitter.py index 462aa860..37cf8391 100644 --- a/tests/unit/test_metadata_splitter.py +++ b/tests/unit/test_metadata_splitter.py @@ -17,7 +17,6 @@ UpdateRunSnapshot, Value, ) -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation from pytest import mark from neptune_scale.api.metrics import Metrics @@ -43,10 +42,9 @@ def test_empty(): # then assert len(result) == 1 - operation, metadata_size = result[0] + operation = result[0] expected_update = UpdateRunSnapshot(timestamp=Timestamp(seconds=1722341532, nanos=21934)) - assert operation == RunOperation(project="workspace/project", run_id="run_id", update=expected_update) - assert metadata_size == expected_update.ByteSize() + assert operation == expected_update @freeze_time("2024-07-30 12:12:12.000022") @@ -74,7 +72,7 @@ def test_configs(): # then assert len(result) == 1 - operation, metadata_size = result[0] + operation = result[0] expected_update = UpdateRunSnapshot( timestamp=Timestamp(seconds=1722341532, nanos=21934), assign={ @@ -86,9 +84,7 @@ def test_configs(): "some/tags": Value(string_set=StringSet(values={"tag1", "tag2"})), }, ) - assert operation == RunOperation(project="workspace/project", run_id="run_id", update=expected_update) - assert metadata_size >= expected_update.ByteSize() - assert metadata_size < operation.ByteSize() + assert operation == expected_update @freeze_time("2024-07-30 12:12:12.000022") @@ -150,7 +146,7 @@ def test_metrics(preview, preview_completion, expected_preview_proto): # then assert len(result) == 1 - operation, metadata_size = result[0] + operation = result[0] expected_update = UpdateRunSnapshot( step=Step(whole=1, micro=0), timestamp=Timestamp(seconds=1722341532, nanos=21934), @@ -159,9 +155,7 @@ def test_metrics(preview, preview_completion, expected_preview_proto): "some/metric": Value(float64=3.14), }, ) 
- assert operation == RunOperation(project="workspace/project", run_id="run_id", update=expected_update) - assert metadata_size >= expected_update.ByteSize() - assert metadata_size < operation.ByteSize() + assert operation == expected_update @freeze_time("2024-07-30 12:12:12.000022") @@ -188,7 +182,7 @@ def test_tags(): # then assert len(result) == 1 - operation, metadata_size = result[0] + operation = result[0] expected_update = UpdateRunSnapshot( timestamp=Timestamp(seconds=1722341532, nanos=21934), modify_sets={ @@ -206,9 +200,7 @@ def test_tags(): ), }, ) - assert operation == RunOperation(project="workspace/project", run_id="run_id", update=expected_update) - assert metadata_size >= expected_update.ByteSize() - assert metadata_size < operation.ByteSize() + assert operation == expected_update @freeze_time("2024-07-30 12:12:12.000022") @@ -244,19 +236,17 @@ def test_splitting(): assert len(result) > 1 # Every message should be smaller than max_size - assert all(len(op.SerializeToString()) <= max_size for op, _ in result) + assert all(len(op.SerializeToString()) <= max_size for op in result) # Common metadata - assert all(op.project == "workspace/project" for op, _ in result) - assert all(op.run_id == "run_id" for op, _ in result) - assert all(op.update.step.whole == 1 for op, _ in result) - assert all(op.update.preview.is_preview if len(op.update.append) > 0 else True for op, _ in result) - assert all(op.update.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op, _ in result) + assert all(op.step.whole == 1 for op in result) + assert all(op.preview.is_preview if len(op.append) > 0 else True for op in result) + assert all(op.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op in result) # Check if all metrics, configs and tags are present in the result - assert sorted([key for op, _ in result for key in op.update.append.keys()]) == sorted(list(metrics.data.keys())) - assert sorted([key for op, _ in result for key in op.update.assign.keys()]) 
== sorted(list(configs.keys())) - assert sorted([key for op, _ in result for key in op.update.modify_sets.keys()]) == sorted( + assert sorted([key for op in result for key in op.append.keys()]) == sorted(list(metrics.data.keys())) + assert sorted([key for op in result for key in op.assign.keys()]) == sorted(list(configs.keys())) + assert sorted([key for op in result for key in op.modify_sets.keys()]) == sorted( list(add_tags.keys()) + list(remove_tags.keys()) ) @@ -290,23 +280,19 @@ def test_split_large_tags(): assert len(result) > 1 # Every message should be smaller than max_size - assert all(len(op.SerializeToString()) <= max_size for op, _ in result) + assert all(len(op.SerializeToString()) <= max_size for op in result) # Common metadata - assert all(op.project == "workspace/project" for op, _ in result) - assert all(op.run_id == "run_id" for op, _ in result) - assert all(op.update.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op, _ in result) + assert all(op.timestamp == Timestamp(seconds=1722341532, nanos=21934) for op in result) # Check if all StringSet values are split correctly - assert {key for op, _ in result for key in op.update.modify_sets.keys()} == set( + assert {key for op in result for key in op.modify_sets.keys()} == set( list(add_tags.keys()) + list(remove_tags.keys()) ) # Check if all tags are present in the result - assert {tag for op, _ in result for tag in op.update.modify_sets["add/tag"].string.values.keys()} == add_tags[ - "add/tag" - ] - assert {tag for op, _ in result for tag in op.update.modify_sets["remove/tag"].string.values.keys()} == remove_tags[ + assert {tag for op in result for tag in op.modify_sets["add/tag"].string.values.keys()} == add_tags["add/tag"] + assert {tag for op in result for tag in op.modify_sets["remove/tag"].string.values.keys()} == remove_tags[ "remove/tag" ] @@ -364,8 +350,8 @@ def test_skip_non_finite_float_metrics(value, caplog): # then assert len(result) == 1 - op, _ = result[0] - assert not 
op.update.assign + operation = result[0] + assert not operation.assign assert "Skipping a non-finite value" in caplog.text assert "bad-metric" in caplog.text diff --git a/tests/unit/test_operations_queue.py b/tests/unit/test_operations_queue.py deleted file mode 100644 index 3a25661a..00000000 --- a/tests/unit/test_operations_queue.py +++ /dev/null @@ -1,230 +0,0 @@ -import logging -import threading -import time -from time import monotonic - -import pytest -from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( - UpdateRunSnapshot, - Value, -) -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation - -from neptune_scale.exceptions import NeptuneUnableToLogData -from neptune_scale.sync.operations_queue import OperationsQueue -from neptune_scale.util import envs - - -def test__enqueue(): - # given - lock = threading.RLock() - queue = OperationsQueue(lock=lock, max_size=0) - - # and - operation = RunOperation() - - # when - queue.enqueue(operation=operation, size=0) - - # then - assert queue._sequence_id == 1 - - # when - queue.enqueue(operation=operation, size=0) - - # then - assert queue._sequence_id == 2 - - -def test_drop_on_max_element_size_exceeded(monkeypatch, caplog): - monkeypatch.setenv(envs.LOG_FAILURE_ACTION, "drop") - - # given - lock = threading.RLock() - queue = OperationsQueue(lock=lock, max_size=1) - - # and - snapshot = UpdateRunSnapshot(assign={f"key_{i}": Value(string=("a" * 1024)) for i in range(1024)}) - operation = RunOperation(update=snapshot) - - # then - with caplog.at_level(logging.ERROR, logger="neptune"): - queue.enqueue(operation=operation, size=snapshot.ByteSize()) - assert len(caplog.records) == 1 - assert "Operation size exceeds the maximum allowed size" in caplog.text - - -def test_raise_on_max_element_size_exceeded(monkeypatch, caplog): - monkeypatch.setenv(envs.LOG_FAILURE_ACTION, "raise") - - lock = threading.RLock() - queue = OperationsQueue(lock=lock, max_size=1) - - snapshot = 
UpdateRunSnapshot(assign={f"key_{i}": Value(string=("a" * 1024)) for i in range(1024)}) - operation = RunOperation(update=snapshot) - - with caplog.at_level(logging.ERROR, logger="neptune"), pytest.raises(NeptuneUnableToLogData) as exc: - queue.enqueue(operation=operation, size=snapshot.ByteSize()) - - assert not caplog.records, "No errors should be logged" - assert exc.match("Operation size exceeds the maximum allowed size") - - -def test_invalid_log_failure_action(monkeypatch): - monkeypatch.setenv(envs.LOG_FAILURE_ACTION, "invalid") - with pytest.raises(ValueError) as exc: - OperationsQueue(lock=threading.RLock(), max_size=1) - exc.match(envs.LOG_FAILURE_ACTION) - - -def test_negative_blocking_time(monkeypatch): - monkeypatch.setenv(envs.LOG_MAX_BLOCKING_TIME_SECONDS, "-1") - with pytest.raises(ValueError) as exc: - OperationsQueue(lock=threading.RLock(), max_size=1) - exc.match(f"{envs.LOG_MAX_BLOCKING_TIME_SECONDS}.* non-negative") - - -def test_invalid_blocking_time(monkeypatch): - monkeypatch.setenv(envs.LOG_MAX_BLOCKING_TIME_SECONDS, "invalid") - with pytest.raises(ValueError) as exc: - OperationsQueue(lock=threading.RLock(), max_size=1) - exc.match(f"{envs.LOG_MAX_BLOCKING_TIME_SECONDS}.* must be an integer") - - -def _get_delayed(queue): - """Used in a thread to consume an element from the queue after a fixed delay""" - time.sleep(1.0) - queue.queue.get(timeout=0.5) - - -def test_enqueue_blocking_timing(monkeypatch): - """Test the blocking behaviour of the queue when depending on whether it's full or empty. 
- - Note that when checking instead of checking if the elapsed time is >= 2.0, - we apply some tolerance (check for > 1.9 instead), as the actual time blocking - on the queue might vary slightly, in particular we can return tiny fractions of a second earlier - """ - - monkeypatch.setenv(envs.LOG_FAILURE_ACTION, "drop") - monkeypatch.setenv(envs.LOG_MAX_BLOCKING_TIME_SECONDS, "2") - - queue = OperationsQueue(lock=threading.RLock(), max_size=1) - - # Queue empty: enqueue should not block - t0 = monotonic() - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert monotonic() - t0 < 0.5, "enqueue() on an empty queue should not block" - - # Queue full: block on first attempt to enqueue an item - t0 = monotonic() - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert monotonic() - t0 > 1.9, "enqueue() on a full queue should block" - - # Queue full: don't block on further failed attempts to enqueue items - t0 = monotonic() - for _ in range(5): - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert monotonic() - t0 < 0.5, "enqueue() on a full queue should not block after the previous call failed" - - # Queue empty again: enqueue should not block - queue.queue.get(timeout=0.5) - t0 = monotonic() - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert monotonic() - t0 < 0.5, "enqueue() on an empty queue should not block" - - # Start a thread that consumes an element from the queue while we're blocked on enqueue() - thread = threading.Thread(target=_get_delayed, args=(queue,), daemon=True) - thread.start() - - # Queue initially full, but is emptied during wait: enqueue should block but only until there is free space - t0 = monotonic() - queue.enqueue(operation=RunOperation(), size=10, key="key") - elapsed = monotonic() - t0 - assert elapsed > 0.5, "enqueue() on a full queue should block" - assert elapsed < 2.0, "Waiting on the queue should be interrupted once there is free capacity" - - # Queue full again: we should 
block - t0 = monotonic() - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert monotonic() - t0 > 1.9, "enqueue() on a full queue should block" - - thread.join() - - -def test_enqueue_drop_on_queue_full(monkeypatch, caplog): - """Test the behaviour of enqueue() with the "drop" action on a full queue""" - - monkeypatch.setenv(envs.LOG_FAILURE_ACTION, "drop") - monkeypatch.setenv(envs.LOG_MAX_BLOCKING_TIME_SECONDS, "2") - caplog.set_level(logging.ERROR, logger="neptune") - - queue = OperationsQueue(lock=threading.RLock(), max_size=1) - - # Queue empty: enqueue must succeed - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert not caplog.records, "enqueue() must succeed on an empty queue" - - # Queue full: drop subsequent items - for _ in range(3): - queue.enqueue(operation=RunOperation(), size=10, key="key") - - assert len(caplog.records) == 3, "An error should be logged for each failed enqueue()" - for rec in caplog.records: - assert "queue is full" in rec.message - - # Queue empty: enqueue must succeed - queue.queue.get(timeout=0.5) - caplog.clear() - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert not caplog.records, "enqueue() must succeed on an empty queue" - - # Start a thread that consumes an element from the queue while we're blocked on enqueue() - thread = threading.Thread(target=_get_delayed, args=(queue,)) - thread.start() - - # Queue initially full, but is emptied during wait: enqueue should succeed - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert not caplog.records, "enqueue() must succeed when queue is emptied during wait" - - # Queue full: drop the item - queue.enqueue(operation=RunOperation(), size=10, key="key") - assert len(caplog.records) == 1, "a single error should be logged" - assert "queue is full" in caplog.text - - thread.join() - - -def test_enqueue_raise_on_queue_full(monkeypatch): - """Test the behaviour of enqueue() with the "drop" action on a full queue""" - - 
monkeypatch.setenv(envs.LOG_FAILURE_ACTION, "raise") - monkeypatch.setenv(envs.LOG_MAX_BLOCKING_TIME_SECONDS, "2") - - queue = OperationsQueue(lock=threading.RLock(), max_size=1) - - # Queue empty: enqueue must succeed - queue.enqueue(operation=RunOperation(), size=10, key="key") - - # Queue full: drop subsequent items - for x in range(3): - with pytest.raises(NeptuneUnableToLogData) as exc: - queue.enqueue(operation=RunOperation(), size=10, key="key") - exc.match("queue is full") - - # Queue empty: enqueue must succeed - queue.queue.get(timeout=0.5) - queue.enqueue(operation=RunOperation(), size=10, key="key") - - # Start a thread that consumes an element from the queue while we're blocked on enqueue() - thread = threading.Thread(target=_get_delayed, args=(queue,)) - thread.start() - - # Queue initially full, but is emptied during wait: enqueue() should succeed - queue.enqueue(operation=RunOperation(), size=10, key="key") - - # Queue full: drop the item - with pytest.raises(NeptuneUnableToLogData) as exc: - queue.enqueue(operation=RunOperation(), size=10, key="key") - exc.match("queue is full") - - thread.join() diff --git a/tests/unit/test_run_resume.py b/tests/unit/test_run_resume.py new file mode 100644 index 00000000..3652f21f --- /dev/null +++ b/tests/unit/test_run_resume.py @@ -0,0 +1,161 @@ +import uuid + +import pytest + +from neptune_scale import Run +from neptune_scale.api.run import _validate_existing_db +from neptune_scale.exceptions import ( + NeptuneConflictingDataInLocalStorage, + NeptuneLocalStorageInUnsupportedVersion, +) +from neptune_scale.sync.operations_repository import Metadata + + +def test_resume_false_with_matching_fork_point(api_token, caplog): + project = "workspace/project" + run_id = str(uuid.uuid4()) + fork_run_id = "parent-run" + fork_step = 5 + + # First create a run to set up the metadata + with Run( + project=project, + api_token=api_token, + run_id=run_id, + mode="disabled", + fork_run_id=fork_run_id, + fork_step=fork_step, + ): 
+ pass + + # Then try to create the same run again without resume + with caplog.at_level("WARNING"): + with Run( + project=project, + api_token=api_token, + run_id=run_id, + resume=False, + mode="disabled", + fork_run_id=fork_run_id, + fork_step=fork_step, + ): + pass + assert "Run already exists in local storage" in caplog.text + + # Then try to use the same run_id with a different project + with Run( + project=project + "2", + api_token=api_token, + run_id=run_id, + resume=False, + mode="disabled", + fork_run_id=fork_run_id, + fork_step=fork_step, + ): + pass + + +def test_resume_false_with_conflicting_fork_point( + api_token, +): + project = "workspace/project" + run_id = str(uuid.uuid4()) + + # First create a run with one fork point + with Run( + project=project, api_token=api_token, run_id=run_id, mode="disabled", fork_run_id="parent-run-1", fork_step=5 + ): + pass + + # Then try to create the same run but with a different fork point + with pytest.raises(NeptuneConflictingDataInLocalStorage): + Run( + project=project, + api_token=api_token, + run_id=run_id, + resume=False, + mode="disabled", + fork_run_id="parent-run-2", + fork_step=10, + ) + + # Then try to use the same run_id and the new fork point with a different project + with Run( + project=project + "2", + api_token=api_token, + run_id=run_id, + resume=False, + mode="disabled", + fork_run_id="parent-run-2", + fork_step=10, + ): + pass + + +def test_resume_true( + api_token, +): + project = "workspace/project" + run_id = str(uuid.uuid4()) + fork_run_id = "parent-run" + fork_step = 5.0 + + # First create a run to set up the metadata + with Run( + project=project, + api_token=api_token, + run_id=run_id, + mode="disabled", + fork_run_id=fork_run_id, + fork_step=fork_step, + ): + pass + + # Then resume the same run without passing a fork point + + with Run( + project=project, + api_token=api_token, + run_id=run_id, + resume=True, + mode="disabled", + ): + pass + + +def test_resume_true_without_fork_point( + api_token, +): + project
= "workspace/project" + run_id = str(uuid.uuid4()) + + # First create a run without a fork point + with Run(project=project, api_token=api_token, run_id=run_id, mode="disabled"): + pass + + # Then resume the run, again without a fork point + with Run(project=project, api_token=api_token, run_id=run_id, resume=True, mode="disabled"): + pass + + +def test_resume_true_with_no_metadata( + api_token, +): + project = "workspace/project" + run_id = str(uuid.uuid4()) + + # Create a run with resume=True but no pre-existing metadata + with Run(project=project, api_token=api_token, run_id=run_id, resume=True, mode="disabled"): + pass + + +def test_unsupported_version_error(): + # Given - create metadata with an unsupported version + metadata = Metadata( + version="unsupported_version", project="project", run_id="run_id", parent_run_id="parent_run_id", fork_step=1.0 + ) + + with pytest.raises(NeptuneLocalStorageInUnsupportedVersion): + _validate_existing_db( + metadata, resume=False, project="project", run_id="run_id", fork_run_id="parent_run_id", fork_step=1.0 + ) diff --git a/tests/unit/test_sync_process.py b/tests/unit/test_sync_process.py index b1b30737..3c8b545c 100644 --- a/tests/unit/test_sync_process.py +++ b/tests/unit/test_sync_process.py @@ -1,15 +1,18 @@ -import queue +import itertools +import os +import tempfile import time +from pathlib import Path from unittest.mock import Mock import neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 as ingest_pb2 import pytest +from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import Run as CreateRun from neptune_api.proto.neptune_pb.ingest.v1.common_pb2 import ( UpdateRunSnapshot, Value, ) from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import SubmitResponse -from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation from neptune_scale import NeptuneScaleWarning from neptune_scale.exceptions import ( @@ -17,16 +20,25 @@ NeptuneSynchronizationStopped, NeptuneUnexpectedError, ) -from
neptune_scale.sync.queue_element import ( - BatchedOperations, - SingleOperation, +from neptune_scale.sync.errors_tracking import ErrorsQueue +from neptune_scale.sync.operations_repository import ( + Metadata, + Operation, + OperationsRepository, + OperationType, + SequenceId, ) +from neptune_scale.sync.parameters import MAX_REQUEST_SIZE_BYTES from neptune_scale.sync.sync_process import ( + PeekableQueue, SenderThread, + StatusTrackingElement, code_to_exception, ) from neptune_scale.util.shared_var import SharedInt +metadata = Metadata(project="project", run_id="run_id", version="v1") + def response(request_ids: list[str], status_code: int = 200): body = SubmitResponse(request_ids=request_ids, request_id=request_ids[-1] if request_ids else None) @@ -35,20 +47,34 @@ def response(request_ids: list[str], status_code: int = 200): def single_operation(update: UpdateRunSnapshot, sequence_id): - operation = RunOperation(update=update) - return SingleOperation( - sequence_id=sequence_id, - timestamp=time.process_time(), - operation=operation.SerializeToString(), - is_batchable=True, - metadata_size=update.ByteSize(), - batch_key=None, + return Operation( + sequence_id=SequenceId(sequence_id), + timestamp=int(time.time() * 1000), + operation_type=OperationType.UPDATE_SNAPSHOT, + operation=update, + operation_size_bytes=update.ByteSize(), ) -def test_sender_thread_work_finishes_when_queue_empty(): +@pytest.fixture +def operations_repository_mock(): + repo = Mock() + repo.get_metadata.side_effect = [metadata] + return repo + + +@pytest.fixture +def operations_repo(): + with tempfile.TemporaryDirectory() as temp_dir: + repo = OperationsRepository(db_path=Path(os.path.join(temp_dir, "test_operations.db"))) + repo.init_db() + repo.save_metadata("project", "run_id") + yield repo + repo.close() + + +def test_sender_thread_work_finishes_when_queue_empty(operations_repository_mock): # given - operations_queue = Mock() status_tracking_queue = Mock() errors_queue = Mock() 
last_queue_seq = SharedInt(initial_value=0) @@ -56,7 +82,7 @@ def test_sender_thread_work_finishes_when_queue_empty(): sender_thread = SenderThread( api_token="", family="", - operations_queue=operations_queue, + operations_repository=operations_repository_mock, status_tracking_queue=status_tracking_queue, errors_queue=errors_queue, last_queued_seq=last_queue_seq, @@ -65,7 +91,7 @@ def test_sender_thread_work_finishes_when_queue_empty(): sender_thread._backend = backend # and - operations_queue.get.side_effect = queue.Empty + operations_repository_mock.get_operations.side_effect = [[]] # when sender_thread.work() @@ -74,9 +100,9 @@ def test_sender_thread_work_finishes_when_queue_empty(): assert True -def test_sender_thread_processes_single_element(): +def test_sender_thread_processes_single_element(operations_repository_mock): # given - operations_queue = Mock() + status_tracking_queue = Mock() errors_queue = Mock() last_queue_seq = SharedInt(initial_value=0) @@ -84,7 +110,7 @@ def test_sender_thread_processes_single_element(): sender_thread = SenderThread( api_token="", family="", - operations_queue=operations_queue, + operations_repository=operations_repository_mock, status_tracking_queue=status_tracking_queue, errors_queue=errors_queue, last_queued_seq=last_queue_seq, @@ -95,10 +121,7 @@ def test_sender_thread_processes_single_element(): # and update = UpdateRunSnapshot(assign={"key": Value(string="a")}) element = single_operation(update, sequence_id=2) - operations_queue.get.side_effect = [ - BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), - queue.Empty, - ] + operations_repository_mock.get_operations.side_effect = [[element], []] # and backend.submit.side_effect = [response(["1"])] @@ -110,9 +133,8 @@ def test_sender_thread_processes_single_element(): assert backend.submit.call_count == 1 -def test_sender_thread_processes_element_on_single_retryable_error(): +def 
test_sender_thread_processes_element_on_single_retryable_error(operations_repository_mock): # given - operations_queue = Mock() status_tracking_queue = Mock() errors_queue = Mock() last_queue_seq = SharedInt(initial_value=0) @@ -120,7 +142,7 @@ def test_sender_thread_processes_element_on_single_retryable_error(): sender_thread = SenderThread( api_token="", family="", - operations_queue=operations_queue, + operations_repository=operations_repository_mock, status_tracking_queue=status_tracking_queue, errors_queue=errors_queue, last_queued_seq=last_queue_seq, @@ -131,10 +153,7 @@ def test_sender_thread_processes_element_on_single_retryable_error(): # and update = UpdateRunSnapshot(assign={"key": Value(string="a")}) element = single_operation(update, sequence_id=2) - operations_queue.get.side_effect = [ - BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), - queue.Empty, - ] + operations_repository_mock.get_operations.side_effect = [[element], []] # and backend.submit.side_effect = [ @@ -151,7 +170,7 @@ def test_sender_thread_processes_element_on_single_retryable_error(): def test_sender_thread_fails_on_regular_error(): # given - operations_queue = Mock() + operations_repository_mock = Mock() status_tracking_queue = Mock() errors_queue = Mock() last_queue_seq = SharedInt(initial_value=0) @@ -159,21 +178,19 @@ def test_sender_thread_fails_on_regular_error(): sender_thread = SenderThread( api_token="", family="", - operations_queue=operations_queue, + operations_repository=operations_repository_mock, status_tracking_queue=status_tracking_queue, errors_queue=errors_queue, last_queued_seq=last_queue_seq, mode="disabled", ) sender_thread._backend = backend + operations_repository_mock.get_metadata.side_effect = [metadata] # and update = UpdateRunSnapshot(assign={"key": Value(string="a")}) element = single_operation(update, sequence_id=2) - operations_queue.get.side_effect = [ - 
BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), - queue.Empty, - ] + operations_repository_mock.get_operations.side_effect = [[element], []] # and backend.submit.side_effect = [ @@ -188,9 +205,8 @@ def test_sender_thread_fails_on_regular_error(): errors_queue.put.assert_called_once() -def test_sender_thread_processes_element_on_429_and_408_http_statuses(): +def test_sender_thread_processes_element_on_429_and_408_http_statuses(operations_repository_mock): # given - operations_queue = Mock() status_tracking_queue = Mock() errors_queue = Mock() last_queue_seq = SharedInt(initial_value=0) @@ -198,7 +214,7 @@ def test_sender_thread_processes_element_on_429_and_408_http_statuses(): sender_thread = SenderThread( api_token="", family="", - operations_queue=operations_queue, + operations_repository=operations_repository_mock, status_tracking_queue=status_tracking_queue, errors_queue=errors_queue, last_queued_seq=last_queue_seq, @@ -209,10 +225,7 @@ def test_sender_thread_processes_element_on_429_and_408_http_statuses(): # and update = UpdateRunSnapshot(assign={"key": Value(string="a")}) element = single_operation(update, sequence_id=2) - operations_queue.get.side_effect = [ - BatchedOperations(sequence_id=element.sequence_id, timestamp=element.timestamp, operation=element.operation), - queue.Empty, - ] + operations_repository_mock.get_operations.side_effect = [[element], []] # and backend.submit.side_effect = [ @@ -228,6 +241,121 @@ def test_sender_thread_processes_element_on_429_and_408_http_statuses(): assert backend.submit.call_count == 3 +def test_sender_thread_processes_elements_with_multiple_operations_in_batch(operations_repo): + status_tracking_queue = PeekableQueue() + errors_queue = ErrorsQueue() + last_queue_seq = SharedInt(initial_value=0) + backend = Mock() + sender_thread = SenderThread( + api_token="a" * 10, + family="test-family", + operations_repository=operations_repo, + 
status_tracking_queue=status_tracking_queue, + errors_queue=errors_queue, + last_queued_seq=last_queue_seq, + mode="disabled", + ) + sender_thread._backend = backend + backend.submit.side_effect = itertools.repeat(response(["a"], status_code=200)) + + # and + updates = [] + for i in range(10): + update = UpdateRunSnapshot(assign={"key": Value(string=f"a{i}")}) + updates.append(update) + last_sequence_id = operations_repo.save_update_run_snapshots(updates) + # when + sender_thread.work() + + # then + assert backend.submit.call_count == 1 + + tracking: list[StatusTrackingElement] = status_tracking_queue.peek(10) # type: ignore + assert len(tracking) == 1 + assert tracking[0].sequence_id == last_sequence_id + + assert operations_repo.get_operations(MAX_REQUEST_SIZE_BYTES) == [] + assert last_queue_seq.value == last_sequence_id + + +def test_sender_thread_processes_elements_with_multiple_operations_in_batches(operations_repo): + status_tracking_queue = PeekableQueue() + errors_queue = ErrorsQueue() + last_queue_seq = SharedInt(initial_value=0) + backend = Mock() + sender_thread = SenderThread( + api_token="a" * 10, + family="test-family", + operations_repository=operations_repo, + status_tracking_queue=status_tracking_queue, + errors_queue=errors_queue, + last_queued_seq=last_queue_seq, + mode="disabled", + ) + sender_thread._backend = backend + backend.submit.side_effect = itertools.repeat(response(["a"], status_code=200)) + + # and + updates = [UpdateRunSnapshot(assign={"key": Value(string=f"a{i}")}) for i in range(10)] + + operations_repo.save_create_run(CreateRun(family="test-run-id", experiment_id="Test Run")) + + last_sequence_id = operations_repo.save_update_run_snapshots(updates) + operations_repo.save_create_run(CreateRun(family="test-run-id", experiment_id="Test Run")) + + last_sequence_id = operations_repo.save_update_run_snapshots(updates) + + # when + sender_thread.work() + + # then + assert backend.submit.call_count == 4 + + tracking: 
list[StatusTrackingElement] = status_tracking_queue.peek(10) + assert len(tracking) == 4 + assert tracking[-1].sequence_id == last_sequence_id + + assert operations_repo.get_operations(MAX_REQUEST_SIZE_BYTES) == [] + assert last_queue_seq.value == last_sequence_id + + +def test_sender_thread_processes_big_operations_in_batches(operations_repo): + status_tracking_queue = PeekableQueue() + errors_queue = ErrorsQueue() + last_queue_seq = SharedInt(initial_value=0) + backend = Mock() + sender_thread = SenderThread( + api_token="a" * 10, + family="test-family", + operations_repository=operations_repo, + status_tracking_queue=status_tracking_queue, + errors_queue=errors_queue, + last_queued_seq=last_queue_seq, + mode="disabled", + ) + sender_thread._backend = backend + backend.submit.side_effect = itertools.repeat(response(["a"], status_code=200)) + + # and + operations_repo.save_create_run(CreateRun(family="test-run-id", experiment_id="Test Run")) + + updates = [UpdateRunSnapshot(assign={"key": Value(string="a" * 1024 * 1024)})] * 30 # 30MB + last_sequence_id = operations_repo.save_update_run_snapshots(updates) + + # when + sender_thread.work() + + # then + assert backend.submit.call_count == 3 + + tracking: list[StatusTrackingElement] = status_tracking_queue.peek(10) + assert len(tracking) == 3 + assert tracking[-1].sequence_id == last_sequence_id + + assert operations_repo.get_operations(MAX_REQUEST_SIZE_BYTES) == [] + assert last_queue_seq.value == last_sequence_id + + @pytest.mark.parametrize( "code", ingest_pb2.IngestCode.DESCRIPTOR.values_by_number.keys(),