Skip to content

chore: Request status tracking #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
args: [ --config-file, pyproject.toml ]
pass_filenames: false
additional_dependencies:
- neptune-api==0.4.0
- neptune-api==0.6.0
- more-itertools
- backoff
default_language_version:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pattern = "default-unprefixed"
[tool.poetry.dependencies]
python = "^3.8"

neptune-api = "0.4.0"
neptune-api = "0.6.0"
more-itertools = "^10.0.0"
psutil = "^5.0.0"
backoff = "^2.0.0"
Expand Down Expand Up @@ -77,10 +77,10 @@ force_grid_wrap = 2
[tool.ruff]
line-length = 120
target-version = "py38"
ignore = ["UP006", "UP007"]

[tool.ruff.lint]
select = ["F", "UP"]
ignore = ["UP006", "UP007"]

[tool.mypy]
files = 'src/neptune_scale'
Expand Down
133 changes: 100 additions & 33 deletions src/neptune_scale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from multiprocessing.sharedctypes import Synchronized
from multiprocessing.synchronize import Condition as ConditionT
from typing import (
Any,
Callable,
Dict,
List,
Expand Down Expand Up @@ -60,6 +61,7 @@
MAX_FAMILY_LENGTH,
MAX_QUEUE_SIZE,
MAX_RUN_ID_LENGTH,
MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
STOP_MESSAGE_FREQUENCY,
)
Expand Down Expand Up @@ -172,15 +174,23 @@ def __init__(
max_queue_size_exceeded_callback=max_queue_size_exceeded_callback,
on_network_error_callback=on_network_error_callback,
)

self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_put_seq_wait: ConditionT = multiprocessing.Condition()

self._last_ack_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_ack_seq_wait: ConditionT = multiprocessing.Condition()

self._sync_process = SyncProcess(
project=self._project,
family=self._family,
operations_queue=self._operations_queue.queue,
errors_queue=self._errors_queue,
api_token=input_api_token,
last_put_seq=self._last_put_seq,
last_put_seq_wait=self._last_put_seq_wait,
last_ack_seq=self._last_ack_seq,
last_ack_seq_wait=self._last_ack_seq_wait,
max_queue_size=max_queue_size,
mode=mode,
)
Expand All @@ -198,6 +208,7 @@ def __init__(
from_run_id=from_run_id,
from_step=from_step,
)
self.wait_for_processing(verbose=False)

@property
def resources(self) -> tuple[Resource, ...]:
Expand All @@ -208,10 +219,9 @@ def resources(self) -> tuple[Resource, ...]:
)

def _close(self) -> None:
# TODO: Change to wait for all operations to be processed
with self._lock:
if self._sync_process.is_alive():
self.wait_for_submission()
self.wait_for_processing()
self._sync_process.terminate()
self._sync_process.join()

Expand Down Expand Up @@ -320,49 +330,106 @@ def log(
for operation in splitter:
self._operations_queue.enqueue(operation=operation)

def wait_for_submission(self, timeout: Optional[float] = None) -> None:
"""
Waits until all metadata is submitted to Neptune.
"""
begin_time = time.time()
logger.info("Waiting for all operations to be processed")
if timeout is None:
def _wait(
self,
phrase: str,
sleep_time: float,
wait_condition: ConditionT,
external_value: Synchronized[int],
timeout: Optional[float] = None,
verbose: bool = True,
) -> None:
if verbose:
logger.info(f"Waiting for all operations to be {phrase}")

if timeout is None and verbose:
logger.warning("No timeout specified. Waiting indefinitely")

with self._lock:
if not self._sync_process.is_alive():
logger.warning("Sync process is not running")
if verbose:
logger.warning("Sync process is not running")
return # No need to wait if the sync process is not running

sleep_time_wait = (
min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME
)
begin_time = time.time()
wait_time = min(sleep_time, timeout) if timeout is not None else sleep_time
last_queued_sequence_id = self._operations_queue.last_sequence_id
last_message_printed: Optional[float] = None
last_print_timestamp: Optional[float] = None

while True:
with self._last_put_seq_wait:
self._last_put_seq_wait.wait(timeout=sleep_time_wait)
value = self._last_put_seq.value
with wait_condition:
wait_condition.wait(timeout=wait_time)
value = external_value.value

if value == -1:
if self._operations_queue.last_sequence_id != -1:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting. No operations processed yet. Operations to sync: %s",
self._operations_queue.last_sequence_id + 1,
)
last_print_timestamp = print_message(
f"Waiting. No operations were {phrase} yet. Operations to sync: %s",
self._operations_queue.last_sequence_id + 1,
last_print=last_print_timestamp,
verbose=verbose,
)
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info("Waiting. No operations processed yet")
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
last_message_printed = time.time()
logger.info(
"Waiting for remaining %d operation(s) to be synced",
last_queued_sequence_id - value + 1,
last_print_timestamp = print_message(
f"Waiting. No operations were {phrase} yet",
last_print=last_print_timestamp,
verbose=verbose,
)
else:
last_print_timestamp = print_message(
f"Waiting for remaining %d operation(s) to be {phrase}",
last_queued_sequence_id - value + 1,
last_print=last_print_timestamp,
verbose=verbose,
)

# Reaching the last queued sequence ID means that all operations were submitted
if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
break

logger.info("All operations processed")
if verbose:
logger.info(f"All operations were {phrase}")

def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
    """
    Waits until all metadata is submitted to Neptune.

    Args:
        timeout (float, optional): In seconds, the maximum time to wait for submission.
        verbose (bool): If True (default), prints messages about the waiting process.
    """
    # Delegate to the shared polling loop, parameterized with the "put"
    # (submission) sequence counter and its condition variable.
    submission_wait_kwargs = dict(
        phrase="submitted",
        sleep_time=MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
        wait_condition=self._last_put_seq_wait,
        external_value=self._last_put_seq,
    )
    self._wait(timeout=timeout, verbose=verbose, **submission_wait_kwargs)

def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
    """
    Waits until all metadata is processed by Neptune.

    Args:
        timeout (float, optional): In seconds, the maximum time to wait for processing.
        verbose (bool): If True (default), prints messages about the waiting process.
    """
    # Delegate to the shared polling loop, parameterized with the "ack"
    # (processing) sequence counter and its condition variable.
    processing_wait_kwargs = dict(
        phrase="processed",
        sleep_time=MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
        wait_condition=self._last_ack_seq_wait,
        external_value=self._last_ack_seq,
    )
    self._wait(timeout=timeout, verbose=verbose, **processing_wait_kwargs)


def print_message(msg: str, *args: Any, last_print: Optional[float] = None, verbose: bool = True) -> Optional[float]:
    """Emit a rate-limited info log message.

    Logs `msg` (with lazy %-style `args`) only when `verbose` is True and at
    least STOP_MESSAGE_FREQUENCY seconds have elapsed since `last_print`.

    Returns the timestamp at which the message was emitted, or the unchanged
    `last_print` value when the message was suppressed.
    """
    now = time.time()

    # Guard clauses: stay silent when not verbose, or when the previous
    # message was printed too recently.
    if not verbose:
        return last_print
    if last_print is not None and now - last_print <= STOP_MESSAGE_FREQUENCY:
        return last_print

    logger.info(msg, *args)
    return now
50 changes: 46 additions & 4 deletions src/neptune_scale/api/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,43 @@
#
from __future__ import annotations

__all__ = ("HostedApiClient", "MockedApiClient", "ApiClient")
__all__ = ("HostedApiClient", "MockedApiClient", "ApiClient", "backend_factory")

import abc
import os
import uuid
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any
from typing import (
Any,
Literal,
)

from httpx import Timeout
from neptune_api import (
AuthenticatedClient,
Client,
)
from neptune_api.api.backend import get_client_config
from neptune_api.api.data_ingestion import submit_operation
from neptune_api.api.data_ingestion import (
check_request_status_bulk,
submit_operation,
)
from neptune_api.auth_helpers import exchange_api_key
from neptune_api.credentials import Credentials
from neptune_api.models import (
ClientConfig,
Error,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import RequestId
from neptune_api.proto.google_rpc.code_pb2 import Code
from neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 import IngestCode
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import (
BulkRequestStatus,
RequestId,
RequestIdList,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus
from neptune_api.types import Response

from neptune_scale.core.components.abstract import Resource
Expand Down Expand Up @@ -95,6 +108,9 @@ class ApiClient(Resource, abc.ABC):
@abc.abstractmethod
def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: ...

@abc.abstractmethod
def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ...


class HostedApiClient(ApiClient):
def __init__(self, api_token: str) -> None:
Expand All @@ -112,6 +128,13 @@ def __init__(self, api_token: str) -> None:
def submit(self, operation: RunOperation, family: str) -> Response[RequestId]:
    # Forward a single RunOperation to the ingestion endpoint, passing the
    # run `family` identifier through to the API call.
    return submit_operation.sync_detailed(client=self._backend, body=operation, family=family)

def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]:
    """Query the backend for the status of previously submitted request ids."""
    # Wrap the raw string ids in the protobuf RequestIdList expected by the API.
    id_list = RequestIdList(ids=[RequestId(value=rid) for rid in request_ids])
    return check_request_status_bulk.sync_detailed(
        client=self._backend,
        project_identifier=project,
        body=id_list,
    )

def close(self) -> None:
    """Release the underlying HTTP backend client."""
    logger.debug("Closing API client")
    # NOTE(review): calls __exit__ directly rather than via a context manager;
    # presumably equivalent to closing the client — confirm against neptune_api.
    self._backend.__exit__()
Expand All @@ -123,3 +146,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

def submit(self, operation: RunOperation, family: str) -> Response[RequestId]:
    """Accept any operation without contacting a backend; answer HTTP 200 with a random request id."""
    fake_request_id = RequestId(value=str(uuid.uuid4()))
    return Response(content=b"", parsed=fake_request_id, status_code=HTTPStatus.OK, headers={})

def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]:
    """Report every queried request as successfully processed.

    Builds one RequestStatus per id in `request_ids`, each claiming a single
    operation ingested with Code.OK / IngestCode.OK, and returns it wrapped in
    an HTTP 200 response. The `project` argument is accepted for interface
    compatibility but unused by the mock.
    """
    # Idiomatic list comprehension instead of list(map(lambda _: ..., ids)).
    statuses = [
        RequestStatus(code_by_count=[RequestStatus.CodeByCount(count=1, code=Code.OK, detail=IngestCode.OK)])
        for _ in request_ids
    ]
    return Response(
        content=b"",
        parsed=BulkRequestStatus(statuses=statuses),
        status_code=HTTPStatus.OK,
        headers={},
    )


def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient:
    """Build the API client for `mode`: a no-op mock when "disabled", otherwise a hosted client."""
    use_mock = mode == "disabled"
    return MockedApiClient() if use_mock else HostedApiClient(api_token=api_token)
9 changes: 9 additions & 0 deletions src/neptune_scale/core/components/errors_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
NeptuneConnectionLostError,
NeptuneOperationsQueueMaxSizeExceeded,
NeptuneScaleError,
NeptuneScaleWarning,
NeptuneUnexpectedError,
)
from neptune_scale.parameters import ERRORS_MONITOR_THREAD_SLEEP_TIME
Expand Down Expand Up @@ -51,13 +52,18 @@ def default_max_queue_size_exceeded_callback(error: BaseException) -> None:
logger.warning(error)


def default_warning_callback(error: BaseException) -> None:
    # Default handler used when no on_warning_callback is supplied:
    # log the warning-level error and continue.
    logger.warning(error)


class ErrorsMonitor(Daemon, Resource):
def __init__(
self,
errors_queue: ErrorsQueue,
max_queue_size_exceeded_callback: Optional[Callable[[BaseException], None]] = None,
on_network_error_callback: Optional[Callable[[BaseException], None]] = None,
on_error_callback: Optional[Callable[[BaseException], None]] = None,
on_warning_callback: Optional[Callable[[BaseException], None]] = None,
):
super().__init__(name="ErrorsMonitor", sleep_time=ERRORS_MONITOR_THREAD_SLEEP_TIME)

Expand All @@ -69,6 +75,7 @@ def __init__(
on_network_error_callback or default_network_error_callback
)
self._on_error_callback: Callable[[BaseException], None] = on_error_callback or default_error_callback
self._on_warning_callback: Callable[[BaseException], None] = on_warning_callback or default_warning_callback

def get_next(self) -> Optional[BaseException]:
try:
Expand All @@ -82,6 +89,8 @@ def work(self) -> None:
self._max_queue_size_exceeded_callback(error)
elif isinstance(error, NeptuneConnectionLostError):
self._non_network_error_callback(error)
elif isinstance(error, NeptuneScaleWarning):
self._on_warning_callback(error)
elif isinstance(error, NeptuneScaleError):
self._on_error_callback(error)
else:
Expand Down
Loading
Loading