diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ff08c31e..b9c6e757 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: args: [ --config-file, pyproject.toml ] pass_filenames: false additional_dependencies: - - neptune-api==0.4.0 + - neptune-api==0.6.0 - more-itertools - backoff default_language_version: diff --git a/pyproject.toml b/pyproject.toml index a6897830..4850781f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ pattern = "default-unprefixed" [tool.poetry.dependencies] python = "^3.8" -neptune-api = "0.4.0" +neptune-api = "0.6.0" more-itertools = "^10.0.0" psutil = "^5.0.0" backoff = "^2.0.0" @@ -77,10 +77,10 @@ force_grid_wrap = 2 [tool.ruff] line-length = 120 target-version = "py38" -ignore = ["UP006", "UP007"] [tool.ruff.lint] select = ["F", "UP"] +ignore = ["UP006", "UP007"] [tool.mypy] files = 'src/neptune_scale' diff --git a/src/neptune_scale/__init__.py b/src/neptune_scale/__init__.py index 12964655..558282b6 100644 --- a/src/neptune_scale/__init__.py +++ b/src/neptune_scale/__init__.py @@ -16,6 +16,7 @@ from multiprocessing.sharedctypes import Synchronized from multiprocessing.synchronize import Condition as ConditionT from typing import ( + Any, Callable, Dict, List, @@ -60,6 +61,7 @@ MAX_FAMILY_LENGTH, MAX_QUEUE_SIZE, MAX_RUN_ID_LENGTH, + MINIMAL_WAIT_FOR_ACK_SLEEP_TIME, MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, STOP_MESSAGE_FREQUENCY, ) @@ -172,15 +174,23 @@ def __init__( max_queue_size_exceeded_callback=max_queue_size_exceeded_callback, on_network_error_callback=on_network_error_callback, ) + self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1) self._last_put_seq_wait: ConditionT = multiprocessing.Condition() + + self._last_ack_seq: Synchronized[int] = multiprocessing.Value("i", -1) + self._last_ack_seq_wait: ConditionT = multiprocessing.Condition() + self._sync_process = SyncProcess( + project=self._project, family=self._family, operations_queue=self._operations_queue.queue, errors_queue=self._errors_queue, api_token=input_api_token, last_put_seq=self._last_put_seq, last_put_seq_wait=self._last_put_seq_wait, + last_ack_seq=self._last_ack_seq, + last_ack_seq_wait=self._last_ack_seq_wait, max_queue_size=max_queue_size, mode=mode, ) @@ -198,6 +208,7 @@ def __init__( from_run_id=from_run_id, from_step=from_step, ) + self.wait_for_processing(verbose=False) @property def resources(self) -> tuple[Resource, ...]: @@ -208,10 +219,9 @@ def resources(self) -> tuple[Resource, ...]: ) def _close(self) -> None: - # TODO: Change to wait for all operations to be processed with self._lock: if self._sync_process.is_alive(): - self.wait_for_submission() + self.wait_for_processing() self._sync_process.terminate() self._sync_process.join() @@ -320,49 +330,106 @@ def log( for operation in splitter: self._operations_queue.enqueue(operation=operation) - def wait_for_submission(self, timeout: Optional[float] = None) -> None: - """ - Waits until all metadata is submitted to Neptune. - """ - begin_time = time.time() - logger.info("Waiting for all operations to be processed") - if timeout is None: + def _wait( + self, + phrase: str, + sleep_time: float, + wait_condition: ConditionT, + external_value: Synchronized[int], + timeout: Optional[float] = None, + verbose: bool = True, + ) -> None: + if verbose: + logger.info(f"Waiting for all operations to be {phrase}") + + if timeout is None and verbose: logger.warning("No timeout specified. 
Waiting indefinitely") with self._lock: if not self._sync_process.is_alive(): - logger.warning("Sync process is not running") + if verbose: + logger.warning("Sync process is not running") return # No need to wait if the sync process is not running - sleep_time_wait = ( - min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME - ) + begin_time = time.time() + wait_time = min(sleep_time, timeout) if timeout is not None else sleep_time last_queued_sequence_id = self._operations_queue.last_sequence_id - last_message_printed: Optional[float] = None + last_print_timestamp: Optional[float] = None + while True: - with self._last_put_seq_wait: - self._last_put_seq_wait.wait(timeout=sleep_time_wait) - value = self._last_put_seq.value + with wait_condition: + wait_condition.wait(timeout=wait_time) + value = external_value.value + if value == -1: if self._operations_queue.last_sequence_id != -1: - if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY: - last_message_printed = time.time() - logger.info( - "Waiting. No operations processed yet. Operations to sync: %s", - self._operations_queue.last_sequence_id + 1, - ) + last_print_timestamp = print_message( + f"Waiting. No operations were {phrase} yet. Operations to sync: %s", + self._operations_queue.last_sequence_id + 1, + last_print=last_print_timestamp, + verbose=verbose, + ) else: - if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY: - last_message_printed = time.time() - logger.info("Waiting. No operations processed yet") - else: - if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY: - last_message_printed = time.time() - logger.info( - "Waiting for remaining %d operation(s) to be synced", - last_queued_sequence_id - value + 1, + last_print_timestamp = print_message( + f"Waiting. No operations were {phrase} yet", + last_print=last_print_timestamp, + verbose=verbose, ) + else: + last_print_timestamp = print_message( + f"Waiting for remaining %d operation(s) to be {phrase}", + last_queued_sequence_id - value + 1, + last_print=last_print_timestamp, + verbose=verbose, + ) + + # Reaching the last queued sequence ID means that all operations were submitted if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout): break - logger.info("All operations processed") + if verbose: + logger.info(f"All operations were {phrase}") + + def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> None: + """ + Waits until all metadata is submitted to Neptune. + + Args: + timeout (float, optional): In seconds, the maximum time to wait for submission. + verbose (bool): If True (default), prints messages about the waiting process. + """ + self._wait( + phrase="submitted", + sleep_time=MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, + wait_condition=self._last_put_seq_wait, + external_value=self._last_put_seq, + timeout=timeout, + verbose=verbose, + ) + + def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> None: + """ + Waits until all metadata is processed by Neptune. + + Args: + timeout (float, optional): In seconds, the maximum time to wait for processing. + verbose (bool): If True (default), prints messages about the waiting process. 
+ """ + self._wait( + phrase="processed", + sleep_time=MINIMAL_WAIT_FOR_ACK_SLEEP_TIME, + wait_condition=self._last_ack_seq_wait, + external_value=self._last_ack_seq, + timeout=timeout, + verbose=verbose, + ) + + +def print_message(msg: str, *args: Any, last_print: Optional[float] = None, verbose: bool = True) -> Optional[float]: + current_time = time.time() + + if verbose and (last_print is None or current_time - last_print > STOP_MESSAGE_FREQUENCY): + logger.info(msg, *args) + return current_time + + return last_print diff --git a/src/neptune_scale/api/api_client.py b/src/neptune_scale/api/api_client.py index 0d8cf301..37fc6c2f 100644 --- a/src/neptune_scale/api/api_client.py +++ b/src/neptune_scale/api/api_client.py @@ -15,14 +15,17 @@ # from __future__ import annotations -__all__ = ("HostedApiClient", "MockedApiClient", "ApiClient") +__all__ = ("HostedApiClient", "MockedApiClient", "ApiClient", "backend_factory") import abc import os import uuid from dataclasses import dataclass from http import HTTPStatus -from typing import Any +from typing import ( + Any, + Literal, +) from httpx import Timeout from neptune_api import ( @@ -30,15 +33,25 @@ Client, ) from neptune_api.api.backend import get_client_config -from neptune_api.api.data_ingestion import submit_operation +from neptune_api.api.data_ingestion import ( + check_request_status_bulk, + submit_operation, +) from neptune_api.auth_helpers import exchange_api_key from neptune_api.credentials import Credentials from neptune_api.models import ( ClientConfig, Error, ) -from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import RequestId +from neptune_api.proto.google_rpc.code_pb2 import Code +from neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 import IngestCode +from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import ( + BulkRequestStatus, + RequestId, + RequestIdList, +) from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation +from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus from neptune_api.types import Response from neptune_scale.core.components.abstract import Resource @@ -95,6 +108,9 @@ class ApiClient(Resource, abc.ABC): @abc.abstractmethod def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: ... + @abc.abstractmethod + def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ... 
+ class HostedApiClient(ApiClient): def __init__(self, api_token: str) -> None: @@ -112,6 +128,13 @@ def __init__(self, api_token: str) -> None: def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: return submit_operation.sync_detailed(client=self._backend, body=operation, family=family) + def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: + return check_request_status_bulk.sync_detailed( + client=self._backend, + project_identifier=project, + body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]), + ) + def close(self) -> None: logger.debug("Closing API client") self._backend.__exit__() @@ -123,3 +146,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: return Response(content=b"", parsed=RequestId(value=str(uuid.uuid4())), status_code=HTTPStatus.OK, headers={}) + + def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: + response_body = BulkRequestStatus( + statuses=list( + map( + lambda _: RequestStatus( + code_by_count=[RequestStatus.CodeByCount(count=1, code=Code.OK, detail=IngestCode.OK)] + ), + request_ids, + ) + ) + ) + return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={}) + + +def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient: + if mode == "disabled": + return MockedApiClient() + return HostedApiClient(api_token=api_token) diff --git a/src/neptune_scale/core/components/errors_tracking.py b/src/neptune_scale/core/components/errors_tracking.py index 18f82aa3..6059d51a 100644 --- a/src/neptune_scale/core/components/errors_tracking.py +++ b/src/neptune_scale/core/components/errors_tracking.py @@ -17,6 +17,7 @@ NeptuneConnectionLostError, NeptuneOperationsQueueMaxSizeExceeded, NeptuneScaleError, + NeptuneScaleWarning, NeptuneUnexpectedError, ) from neptune_scale.parameters import ERRORS_MONITOR_THREAD_SLEEP_TIME @@ -51,6 +52,10 @@ def default_max_queue_size_exceeded_callback(error: BaseException) -> None: logger.warning(error) +def default_warning_callback(error: BaseException) -> None: + logger.warning(error) + + class ErrorsMonitor(Daemon, Resource): def __init__( self, @@ -58,6 +63,7 @@ def __init__( max_queue_size_exceeded_callback: Optional[Callable[[BaseException], None]] = None, on_network_error_callback: Optional[Callable[[BaseException], None]] = None, on_error_callback: Optional[Callable[[BaseException], None]] = None, + on_warning_callback: Optional[Callable[[BaseException], None]] = None, ): super().__init__(name="ErrorsMonitor", sleep_time=ERRORS_MONITOR_THREAD_SLEEP_TIME) @@ -69,6 +75,7 @@ def __init__( on_network_error_callback or default_network_error_callback ) self._on_error_callback: Callable[[BaseException], None] = on_error_callback or default_error_callback + self._on_warning_callback: Callable[[BaseException], None] = on_warning_callback or default_warning_callback def get_next(self) -> Optional[BaseException]: try: @@ -82,6 +89,8 @@ def work(self) -> None: self._max_queue_size_exceeded_callback(error) elif isinstance(error, NeptuneConnectionLostError): self._non_network_error_callback(error) + elif isinstance(error, NeptuneScaleWarning): + self._on_warning_callback(error) elif isinstance(error, NeptuneScaleError): self._on_error_callback(error) else: diff --git a/src/neptune_scale/core/components/sync_process.py b/src/neptune_scale/core/components/sync_process.py index 
34c4802f..5b8f99dd 100644 --- a/src/neptune_scale/core/components/sync_process.py +++ b/src/neptune_scale/core/components/sync_process.py @@ -4,6 +4,7 @@ import multiprocessing import queue +import threading from multiprocessing import ( Process, Queue, @@ -13,8 +14,14 @@ from typing import ( Any, Callable, + Dict, + Generic, + List, Literal, + NamedTuple, Optional, + Type, + TypeVar, ) import backoff @@ -26,14 +33,17 @@ UnableToRefreshTokenError, UnexpectedStatus, ) -from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import RequestId +from neptune_api.proto.google_rpc.code_pb2 import Code +from neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 import IngestCode +from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import ( + BulkRequestStatus, + RequestId, +) from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation -from neptune_api.types import Response from neptune_scale.api.api_client import ( ApiClient, - HostedApiClient, - MockedApiClient, + backend_factory, ) from neptune_scale.core.components.abstract import ( Resource, @@ -45,20 +55,109 @@ from neptune_scale.core.logger import logger from neptune_scale.exceptions import ( NeptuneConnectionLostError, + NeptuneFieldPathEmpty, + NeptuneFieldPathExceedsSizeLimit, + NeptuneFieldPathInvalid, + NeptuneFieldPathNonWritable, + NeptuneFieldTypeConflicting, + NeptuneFieldTypeUnsupported, + NeptuneFloatValueNanInfUnsupported, NeptuneInvalidCredentialsError, NeptuneOperationsQueueMaxSizeExceeded, + NeptuneProjectInvalidName, + NeptuneProjectNotFound, NeptuneRetryableError, + NeptuneRunConflicting, + NeptuneRunDuplicate, + NeptuneRunForkParentNotFound, + NeptuneRunInvalidCreationParameters, + NeptuneRunNotFound, + NeptuneSeriesPointDuplicate, + NeptuneSeriesStepNonIncreasing, + NeptuneSeriesStepNotAfterForkPoint, + NeptuneSeriesTimestampDecreasing, + NeptuneStringSetExceedsSizeLimit, + NeptuneStringValueExceedsSizeLimit, NeptuneUnableToAuthenticateError, NeptuneUnauthorizedError, + NeptuneUnexpectedError, ) from neptune_scale.parameters import ( EXTERNAL_TO_INTERNAL_THREAD_SLEEP_TIME, MAX_QUEUE_SIZE, + MAX_REQUESTS_STATUS_BATCH_SIZE, OPERATION_TIMEOUT, SHUTDOWN_TIMEOUT, + STATUS_TRACKING_THREAD_SLEEP_TIME, SYNC_THREAD_SLEEP_TIME, ) +T = TypeVar("T") + + +CODE_TO_ERROR: Dict[IngestCode.ValueType, Optional[Type[Exception]]] = { + IngestCode.OK: None, + IngestCode.PROJECT_NOT_FOUND: NeptuneProjectNotFound, + IngestCode.PROJECT_INVALID_NAME: NeptuneProjectInvalidName, + IngestCode.RUN_NOT_FOUND: NeptuneRunNotFound, + IngestCode.RUN_DUPLICATE: NeptuneRunDuplicate, + IngestCode.RUN_CONFLICTING: NeptuneRunConflicting, + IngestCode.RUN_FORK_PARENT_NOT_FOUND: NeptuneRunForkParentNotFound, + IngestCode.RUN_INVALID_CREATION_PARAMETERS: NeptuneRunInvalidCreationParameters, + IngestCode.FIELD_PATH_EXCEEDS_SIZE_LIMIT: NeptuneFieldPathExceedsSizeLimit, + IngestCode.FIELD_PATH_EMPTY: NeptuneFieldPathEmpty, + IngestCode.FIELD_PATH_INVALID: NeptuneFieldPathInvalid, + IngestCode.FIELD_PATH_NON_WRITABLE: NeptuneFieldPathNonWritable, + IngestCode.FIELD_TYPE_UNSUPPORTED: NeptuneFieldTypeUnsupported, + IngestCode.FIELD_TYPE_CONFLICTING: NeptuneFieldTypeConflicting, + IngestCode.SERIES_POINT_DUPLICATE: NeptuneSeriesPointDuplicate, + IngestCode.SERIES_STEP_NON_INCREASING: NeptuneSeriesStepNonIncreasing, + IngestCode.SERIES_STEP_NOT_AFTER_FORK_POINT: NeptuneSeriesStepNotAfterForkPoint, + IngestCode.SERIES_TIMESTAMP_DECREASING: NeptuneSeriesTimestampDecreasing, + IngestCode.FLOAT_VALUE_NAN_INF_UNSUPPORTED: 
NeptuneFloatValueNanInfUnsupported, + IngestCode.STRING_VALUE_EXCEEDS_SIZE_LIMIT: NeptuneStringValueExceedsSizeLimit, + IngestCode.STRING_SET_EXCEEDS_SIZE_LIMIT: NeptuneStringSetExceedsSizeLimit, +} + + +class StatusTrackingElement(NamedTuple): + sequence_id: int + request_id: str + + +def code_to_exception(code: IngestCode.ValueType) -> Optional[Type[Exception]]: + if code in CODE_TO_ERROR: + return CODE_TO_ERROR[code] + return NeptuneUnexpectedError + + +class PeekableQueue(Generic[T]): + def __init__(self) -> None: + self._lock: threading.RLock = threading.RLock() + self._queue: queue.Queue[T] = queue.Queue() + + def put(self, element: T) -> None: + with self._lock: + self._queue.put(element) + + def peek(self, max_size: int) -> Optional[List[T]]: + with self._lock: + size = self._queue.qsize() + if size == 0: + return None + + items = [] + for i in range(min(size, max_size)): + item = self._queue.queue[i] + items.append(item) + return items + + def commit(self, n: int) -> None: + with self._lock: + size = self._queue.qsize() + for _ in range(min(size, n)): + self._queue.get() + def with_api_errors_handling(func: Callable[..., Any]) -> Callable[..., Any]: def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -82,10 +181,13 @@ def __init__( operations_queue: Queue, errors_queue: ErrorsQueue, api_token: str, + project: str, family: str, mode: Literal["async", "disabled"], last_put_seq: Synchronized[int], last_put_seq_wait: Condition, + last_ack_seq: Synchronized[int], + last_ack_seq_wait: Condition, max_queue_size: int = MAX_QUEUE_SIZE, ) -> None: super().__init__(name="SyncProcess") @@ -93,21 +195,27 @@ def __init__( self._external_operations_queue: Queue[QueueElement] = operations_queue self._errors_queue: ErrorsQueue = errors_queue self._api_token: str = api_token + self._project: str = project self._family: str = family self._last_put_seq: Synchronized[int] = last_put_seq self._last_put_seq_wait: Condition = last_put_seq_wait + self._last_ack_seq: Synchronized[int] = last_ack_seq + self._last_ack_seq_wait: Condition = last_ack_seq_wait self._max_queue_size: int = max_queue_size self._mode: Literal["async", "disabled"] = mode def run(self) -> None: logger.info("Data synchronization started") worker = SyncProcessWorker( + project=self._project, family=self._family, api_token=self._api_token, errors_queue=self._errors_queue, external_operations_queue=self._external_operations_queue, last_put_seq=self._last_put_seq, last_put_seq_wait=self._last_put_seq_wait, + last_ack_seq=self._last_ack_seq, + last_ack_seq_wait=self._last_ack_seq_wait, max_queue_size=self._max_queue_size, mode=self._mode, ) @@ -126,20 +234,25 @@ def __init__( self, *, api_token: str, + project: str, family: str, + mode: Literal["async", "disabled"], errors_queue: ErrorsQueue, external_operations_queue: multiprocessing.Queue[QueueElement], last_put_seq: Synchronized[int], - mode: Literal["async", "disabled"], last_put_seq_wait: Condition, + last_ack_seq: Synchronized[int], + last_ack_seq_wait: Condition, max_queue_size: int = MAX_QUEUE_SIZE, ) -> None: self._errors_queue = errors_queue self._internal_operations_queue: queue.Queue[QueueElement] = queue.Queue(maxsize=max_queue_size) - self._sync_thread = SyncThread( + self._status_tracking_queue: PeekableQueue[StatusTrackingElement] = PeekableQueue() + self._sync_thread = SenderThread( api_token=api_token, operations_queue=self._internal_operations_queue, + status_tracking_queue=self._status_tracking_queue, errors_queue=self._errors_queue, family=family, 
last_put_seq=last_put_seq, @@ -151,14 +264,23 @@ def __init__( internal=self._internal_operations_queue, errors_queue=self._errors_queue, ) + self._status_tracking_thread = StatusTrackingThread( + api_token=api_token, + mode=mode, + project=project, + errors_queue=self._errors_queue, + status_tracking_queue=self._status_tracking_queue, + last_ack_seq=last_ack_seq, + last_ack_seq_wait=last_ack_seq_wait, + ) @property def threads(self) -> tuple[Daemon, ...]: - return self._external_to_internal_thread, self._sync_thread + return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread @property def resources(self) -> tuple[Resource, ...]: - return self._external_to_internal_thread, self._sync_thread + return self._external_to_internal_thread, self._sync_thread, self._status_tracking_thread def interrupt(self) -> None: for thread in self.threads: @@ -214,41 +336,30 @@ def work(self) -> None: self._errors_queue.put(e) -def raise_for_status(response: Response[RequestId]) -> None: - if response.status_code == 403: - raise NeptuneUnauthorizedError() - if response.status_code != 200: - raise RuntimeError(f"Unexpected status code: {response.status_code}") - - -def _ensure_backend_initialized(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient: - if mode == "disabled": - return MockedApiClient() - return HostedApiClient(api_token=api_token) - - -class SyncThread(Daemon, WithResources): +class SenderThread(Daemon, WithResources): def __init__( self, api_token: str, operations_queue: queue.Queue[QueueElement], + status_tracking_queue: PeekableQueue[StatusTrackingElement], errors_queue: ErrorsQueue, family: str, last_put_seq: Synchronized[int], last_put_seq_wait: Condition, mode: Literal["async", "disabled"], ) -> None: - super().__init__(name="SyncThread", sleep_time=SYNC_THREAD_SLEEP_TIME) + super().__init__(name="SenderThread", sleep_time=SYNC_THREAD_SLEEP_TIME) self._api_token: str = api_token self._operations_queue: queue.Queue[QueueElement] = operations_queue + self._status_tracking_queue: PeekableQueue[StatusTrackingElement] = status_tracking_queue self._errors_queue: ErrorsQueue = errors_queue - self._backend: Optional[ApiClient] = None self._family: str = family self._last_put_seq: Synchronized[int] = last_put_seq self._last_put_seq_wait: Condition = last_put_seq_wait self._mode: Literal["async", "disabled"] = mode + self._backend: Optional[ApiClient] = None self._latest_unprocessed: Optional[QueueElement] = None def get_next(self) -> Optional[QueueElement]: @@ -268,12 +379,18 @@ def resources(self) -> tuple[Resource, ...]: @backoff.on_exception(backoff.expo, NeptuneConnectionLostError, max_time=OPERATION_TIMEOUT) @with_api_errors_handling - def submit(self, *, operation: RunOperation) -> None: + def submit(self, *, operation: RunOperation) -> Optional[RequestId]: if self._backend is None: - self._backend = _ensure_backend_initialized(api_token=self._api_token, mode=self._mode) + self._backend = backend_factory(api_token=self._api_token, mode=self._mode) response = self._backend.submit(operation=operation, family=self._family) logger.debug("Server response:", response) - raise_for_status(response) + + if response.status_code == 403: + raise NeptuneUnauthorizedError() + if response.status_code != 200: + raise RuntimeError(f"Unexpected status code: {response.status_code}") + + return response.parsed def work(self) -> None: while (operation := self.get_next()) is not None: @@ -283,7 +400,11 @@ def work(self) -> None: try: run_operation = RunOperation() 
run_operation.ParseFromString(data) - self.submit(operation=run_operation) + request_id = self.submit(operation=run_operation) + if request_id: + self._status_tracking_queue.put( + StatusTrackingElement(sequence_id=sequence_id, request_id=request_id.value) + ) except NeptuneRetryableError as e: self._errors_queue.put(e) continue @@ -299,3 +420,100 @@ def work(self) -> None: with self._last_put_seq_wait: self._last_put_seq.value = sequence_id self._last_put_seq_wait.notify_all() + + +class StatusTrackingThread(Daemon, WithResources): + def __init__( + self, + api_token: str, + mode: Literal["async", "disabled"], + project: str, + errors_queue: ErrorsQueue, + status_tracking_queue: PeekableQueue[StatusTrackingElement], + last_ack_seq: Synchronized[int], + last_ack_seq_wait: Condition, + ) -> None: + super().__init__(name="StatusTrackingThread", sleep_time=STATUS_TRACKING_THREAD_SLEEP_TIME) + + self._api_token: str = api_token + self._mode: Literal["async", "disabled"] = mode + self._project: str = project + self._errors_queue: ErrorsQueue = errors_queue + self._status_tracking_queue: PeekableQueue[StatusTrackingElement] = status_tracking_queue + self._last_ack_seq: Synchronized[int] = last_ack_seq + self._last_ack_seq_wait: Condition = last_ack_seq_wait + + self._backend: Optional[ApiClient] = None + + @property + def resources(self) -> tuple[Resource, ...]: + if self._backend is not None: + return (self._backend,) + return () + + def get_next(self) -> Optional[List[StatusTrackingElement]]: + try: + return self._status_tracking_queue.peek(max_size=MAX_REQUESTS_STATUS_BATCH_SIZE) + except queue.Empty: + return None + + @backoff.on_exception(backoff.expo, NeptuneConnectionLostError, max_time=OPERATION_TIMEOUT) + @with_api_errors_handling + def check_batch(self, *, request_ids: List[str]) -> Optional[BulkRequestStatus]: + if self._backend is None: + self._backend = backend_factory(api_token=self._api_token, mode=self._mode) + + response = self._backend.check_batch(request_ids=request_ids, project=self._project) + logger.debug("Server response:", response) + + if response.status_code == 403: + raise NeptuneUnauthorizedError() + if response.status_code != 200: + raise RuntimeError(f"Unexpected status code: {response.status_code}") + + return response.parsed + + def work(self) -> None: + while (batch := self.get_next()) is not None: + request_ids = [element.request_id for element in batch] + sequence_ids = [element.sequence_id for element in batch] + + try: + response = self.check_batch(request_ids=request_ids) + if response is None: + break + except NeptuneRetryableError as e: + self._errors_queue.put(e) + break + except Exception as e: + self._errors_queue.put(e) + self.interrupt() + self._last_ack_seq_wait.notify_all() + break + + to_commit = 0 + sequence_id = None + for request_status, request_seq_id in zip(response.statuses, sequence_ids): + codes = [code_status.code for code_status in request_status.code_by_count] + if Code.UNAVAILABLE in codes: + break + + detailed_codes = [code_status.detail for code_status in request_status.code_by_count] + for detailed_code in detailed_codes: + if (error := code_to_exception(detailed_code)) is not None: + self._errors_queue.put(error()) + + to_commit += 1 + sequence_id = request_seq_id + + if to_commit > 0: + self._status_tracking_queue.commit(to_commit) + + # Update Last ACK sequence id and notify threads in the main process + if sequence_id is not None: + with self._last_ack_seq_wait: + self._last_ack_seq.value = sequence_id + 
self._last_ack_seq_wait.notify_all() + else: + # Sleep before retry + break diff --git a/src/neptune_scale/exceptions.py b/src/neptune_scale/exceptions.py index b43bd290..30a59692 100644 --- a/src/neptune_scale/exceptions.py +++ b/src/neptune_scale/exceptions.py @@ -2,6 +2,7 @@ __all__ = ( "NeptuneScaleError", + "NeptuneScaleWarning", "NeptuneOperationsQueueMaxSizeExceeded", "NeptuneUnauthorizedError", "NeptuneInvalidCredentialsError", @@ -9,6 +10,26 @@ "NeptuneConnectionLostError", "NeptuneUnableToAuthenticateError", "NeptuneRetryableError", + "NeptuneProjectNotFound", + "NeptuneProjectInvalidName", + "NeptuneRunNotFound", + "NeptuneRunDuplicate", + "NeptuneRunConflicting", + "NeptuneRunForkParentNotFound", + "NeptuneRunInvalidCreationParameters", + "NeptuneFieldPathExceedsSizeLimit", + "NeptuneFieldPathEmpty", + "NeptuneFieldPathInvalid", + "NeptuneFieldPathNonWritable", + "NeptuneFieldTypeUnsupported", + "NeptuneFieldTypeConflicting", + "NeptuneSeriesPointDuplicate", + "NeptuneSeriesStepNonIncreasing", + "NeptuneSeriesStepNotAfterForkPoint", + "NeptuneSeriesTimestampDecreasing", + "NeptuneFloatValueNanInfUnsupported", + "NeptuneStringValueExceedsSizeLimit", + "NeptuneStringSetExceedsSizeLimit", ) from typing import Any @@ -27,6 +48,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(self.message.format(*args, **STYLES, **kwargs)) +class NeptuneScaleWarning(Warning): + message = "A warning occurred in the Neptune Scale client." + + def __init__(self, *args: Any, **kwargs: Any) -> None: + ensure_style_detected() + super().__init__(self.message.format(*args, **STYLES, **kwargs)) + + class NeptuneOperationsQueueMaxSizeExceeded(NeptuneScaleError): message = """ {h1} @@ -153,3 +182,265 @@ class NeptuneUnableToAuthenticateError(NeptuneScaleError): Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. """ + + +class NeptuneProjectNotFound(NeptuneScaleError): + message = """ +{h1} +----NeptuneProjectNotFound----------------------------------------------------- +{end} +Project not found. Either the project hasn't been created yet or the name is incorrect. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneProjectInvalidName(NeptuneScaleError): + message = """ +{h1} +----NeptuneProjectInvalidName-------------------------------------------------- +{end} +Project name is either empty or too long. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneRunNotFound(NeptuneScaleError): + message = """ +{h1} +----NeptuneRunNotFound--------------------------------------------------------- +{end} +Run not found. May happen when the run is not yet created. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneRunDuplicate(NeptuneScaleWarning): + message = """ +{h1} +----NeptuneRunDuplicate-------------------------------------------------------- +{end} +Identical run already exists. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. 
+""" + + +class NeptuneRunConflicting(NeptuneScaleError): + message = """ +{h1} +----NeptuneRunConflicting------------------------------------------------------ +{end} +Run with specified `run_id` already exists, but has different creation parameters (`family` or `from_run_id`). + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneRunForkParentNotFound(NeptuneScaleWarning): + message = """ +{h1} +----NeptuneRunForkParentNotFound----------------------------------------------- +{end} +Missing fork parent run. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneRunInvalidCreationParameters(NeptuneScaleError): + message = """ +{h1} +----NeptuneRunInvalidCreationParameters---------------------------------------- +{end} +Invalid run creation parameters. For example, the experiment name is too large. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneFieldPathExceedsSizeLimit(NeptuneScaleError): + message = """ +{h1} +----NeptuneFieldPathExceedsSizeLimit------------------------------------------- +{end} +Field path is too long. Maximum length is 1024 bytes (not characters). + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneFieldPathEmpty(NeptuneScaleError): + message = """ +{h1} +----NeptuneFieldPathEmpty------------------------------------------------------ +{end} +Field path is empty. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneFieldPathInvalid(NeptuneScaleError): + message = """ +{h1} +----NeptuneFieldPathInvalid---------------------------------------------------- +{end} +Field path is invalid. To troubleshoot the problem, ensure that the UTF-8 encoding is valid. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneFieldPathNonWritable(NeptuneScaleError): + message = """ +{h1} +----NeptuneFieldPathNonWritable------------------------------------------------ +{end} +Field path is non-writable. Some special sys/ fields are read-only. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneFieldTypeUnsupported(NeptuneScaleError): + message = """ +{h1} +----NeptuneFieldTypeUnsupported------------------------------------------------ +{end} +Field type is not supported by the system. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. 
+""" + + +class NeptuneFieldTypeConflicting(NeptuneScaleError): + message = """ +{h1} +----NeptuneFieldTypeConflicting------------------------------------------------ +{end} +Field type is different from the one that was previously logged for this series. +Once a field type is set, it cannot be changed. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneSeriesPointDuplicate(NeptuneScaleWarning): + message = """ +{h1} +----NeptuneSeriesPointDuplicate------------------------------------------------ +{end} +The exact same data point was already logged for this series. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneSeriesStepNonIncreasing(NeptuneScaleError): + message = """ +{h1} +----NeptuneSeriesStepNonIncreasing--------------------------------------------- +{end} +The step of a series value is smaller than the most recently logged step for this series or the step is exactly the same, + but the value is different. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneSeriesStepNotAfterForkPoint(NeptuneScaleError): + message = """ +{h1} +----NeptuneSeriesStepNotAfterForkPoint----------------------------------------- +{end} +The series value must be greater than the step specified by the `from_step` argument. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneSeriesTimestampDecreasing(NeptuneScaleError): + message = """ +{h1} +----NeptuneSeriesTimestampDecreasing------------------------------------------- +{end} +The timestamp of a series value is less than the most recently logged value. Identical timestamps are allowed. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneFloatValueNanInfUnsupported(NeptuneScaleError): + message = """ +{h1} +----NeptuneFloatValueNanInfUnsupported----------------------------------------- +{end} +Unsupported value type for float64 field or float64 series. Applies to Inf and NaN values. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneStringValueExceedsSizeLimit(NeptuneScaleError): + message = """ +{h1} +----NeptuneStringValueExceedsSizeLimit----------------------------------------- +{end} +String value is too long. Maximum length is 64KB. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`. +""" + + +class NeptuneStringSetExceedsSizeLimit(NeptuneScaleError): + message = """ +{h1} +----NeptuneStringSetExceedsSizeLimit------------------------------------------- +{end} +String Set value is too long. Maximum length is 64KB. + +{correct}Need help?{end}-> https://docs.neptune.ai/getting_help + +Struggling with the formatting? 
To disable it, set the `NEPTUNE_DISABLE_COLORS` environment variable to `True`.
+"""
diff --git a/src/neptune_scale/parameters.py b/src/neptune_scale/parameters.py
index 3c7d65b4..df0ed52b 100644
--- a/src/neptune_scale/parameters.py
+++ b/src/neptune_scale/parameters.py
@@ -4,10 +4,13 @@
 MAX_MULTIPROCESSING_QUEUE_SIZE = 32767
 MAX_QUEUE_ELEMENT_SIZE = 1024 * 1024  # 1MB
 SYNC_THREAD_SLEEP_TIME = 0.1
+STATUS_TRACKING_THREAD_SLEEP_TIME = 1
 EXTERNAL_TO_INTERNAL_THREAD_SLEEP_TIME = 0.1
 ERRORS_MONITOR_THREAD_SLEEP_TIME = 0.1
 SHUTDOWN_TIMEOUT = 60  # 1 minute
 MINIMAL_WAIT_FOR_PUT_SLEEP_TIME = 10
+MINIMAL_WAIT_FOR_ACK_SLEEP_TIME = 10
 STOP_MESSAGE_FREQUENCY = 5
 REQUEST_TIMEOUT = 5
 OPERATION_TIMEOUT = 60
+MAX_REQUESTS_STATUS_BATCH_SIZE = 1000
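
Notes on the APIs introduced above (illustrative sketches, not part of the patch):

A minimal usage sketch of the two waiting modes: `wait_for_submission()` returns once queued operations have been handed off to the backend (tracked via `last_put_seq`), while the new `wait_for_processing()` additionally waits for server-side acknowledgement (tracked via `last_ack_seq`). The class name `Run`, its constructor arguments, and the `log()` call shown here are assumptions about the public neptune-scale API; only the `wait_for_*` signatures come from this diff.

```python
# Hypothetical user-side code; `Run`, its constructor arguments, and `log()` are assumed.
from neptune_scale import Run

run = Run(
    project="workspace/project",  # hypothetical project identifier
    run_id="training-run-1",      # hypothetical run id
    family="training-run-1",
)

run.log(step=1, metrics={"loss": 0.25})

# Returns once all queued operations were submitted (PUT) to the backend.
run.wait_for_submission(timeout=60)

# Returns once the backend acknowledged (processed) the operations. verbose=False
# suppresses the periodic "Waiting..." messages, as done internally in __init__.
run.wait_for_processing(timeout=60, verbose=False)
```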
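
The sender and status-tracking threads hand work off through the new `PeekableQueue`: the sender `put()`s one `StatusTrackingElement` per submitted operation, and the tracker `peek()`s a batch, checks it with the backend, and `commit()`s only the acknowledged prefix. A small sketch of that peek-then-commit contract, assuming this branch of the package is installed:

```python
from neptune_scale.core.components.sync_process import (
    PeekableQueue,
    StatusTrackingElement,
)

tracking: PeekableQueue[StatusTrackingElement] = PeekableQueue()

# Producer side (SenderThread): one element per successfully submitted operation.
for seq_id in range(3):
    tracking.put(StatusTrackingElement(sequence_id=seq_id, request_id=f"req-{seq_id}"))

# Consumer side (StatusTrackingThread): peek a batch without removing it.
batch = tracking.peek(max_size=2)
if batch is not None:
    print([element.request_id for element in batch])  # ['req-0', 'req-1']

    # Commit only the prefix whose statuses came back without UNAVAILABLE codes;
    # uncommitted elements stay in the queue and are re-checked on the next pass.
    tracking.commit(len(batch))

print(tracking.peek(max_size=2))  # only StatusTrackingElement(sequence_id=2, ...) remains
```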
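
The status-check path itself can be exercised offline through the `disabled` mode, since `backend_factory()` then returns `MockedApiClient`, whose `check_batch()` reports every request as `OK`. The request ids and project name below are made up for illustration:

```python
from neptune_scale.api.api_client import backend_factory

client = backend_factory(api_token="", mode="disabled")  # returns MockedApiClient

response = client.check_batch(request_ids=["req-1", "req-2"], project="workspace/project")
for status in response.parsed.statuses:
    # The mocked backend returns a single (count=1, code=OK, detail=OK) entry per request.
    print([(entry.count, entry.code, entry.detail) for entry in status.code_by_count])
```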