
chore: Request status tracking #18


Merged: 18 commits merged on Aug 22, 2024
Changes from 16 commits
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -27,7 +27,7 @@ repos:
args: [ --config-file, pyproject.toml ]
pass_filenames: false
additional_dependencies:
- neptune-api==0.4.0
- neptune-api==0.6.0
- more-itertools
- backoff
default_language_version:
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -11,7 +11,7 @@ pattern = "default-unprefixed"
[tool.poetry.dependencies]
python = "^3.8"

neptune-api = "0.4.0"
neptune-api = "0.6.0"
more-itertools = "^10.0.0"
psutil = "^5.0.0"
backoff = "^2.0.0"
@@ -77,10 +77,10 @@ force_grid_wrap = 2
[tool.ruff]
line-length = 120
target-version = "py38"
ignore = ["UP006", "UP007"]

[tool.ruff.lint]
select = ["F", "UP"]
ignore = ["UP006", "UP007"]

[tool.mypy]
files = 'src/neptune_scale'
105 changes: 82 additions & 23 deletions src/neptune_scale/__init__.py
@@ -60,6 +60,7 @@
MAX_FAMILY_LENGTH,
MAX_QUEUE_SIZE,
MAX_RUN_ID_LENGTH,
MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
STOP_MESSAGE_FREQUENCY,
)
@@ -172,15 +173,23 @@ def __init__(
max_queue_size_exceeded_callback=max_queue_size_exceeded_callback,
on_network_error_callback=on_network_error_callback,
)

self._last_put_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_put_seq_wait: ConditionT = multiprocessing.Condition()

self._last_ack_seq: Synchronized[int] = multiprocessing.Value("i", -1)
self._last_ack_seq_wait: ConditionT = multiprocessing.Condition()

self._sync_process = SyncProcess(
project=self._project,
family=self._family,
operations_queue=self._operations_queue.queue,
errors_queue=self._errors_queue,
api_token=input_api_token,
last_put_seq=self._last_put_seq,
last_put_seq_wait=self._last_put_seq_wait,
last_ack_seq=self._last_ack_seq,
last_ack_seq_wait=self._last_ack_seq_wait,
max_queue_size=max_queue_size,
mode=mode,
)
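The shared counters and conditions introduced above form a simple cross-process handshake: the sync process is expected to advance last_put_seq / last_ack_seq and notify the paired condition, while the parent process waits on it (see _wait below). A minimal sketch of that pattern, with hypothetical function names and not the actual SyncProcess code (which is outside this diff):

```python
import multiprocessing


def signal_progress(seq, cond, sequence_id: int) -> None:
    # Consumer side (the role SyncProcess plays): record progress and wake up any waiter.
    with cond:
        seq.value = max(seq.value, sequence_id)
        cond.notify_all()


def wait_for_progress(seq, cond, target: int, poll: float = 0.5) -> None:
    # Producer side (the parent process): block until the counter reaches the target.
    with cond:
        while seq.value < target:
            cond.wait(timeout=poll)


if __name__ == "__main__":
    last_ack_seq = multiprocessing.Value("i", -1)
    last_ack_seq_wait = multiprocessing.Condition()

    signal_progress(last_ack_seq, last_ack_seq_wait, sequence_id=0)
    wait_for_progress(last_ack_seq, last_ack_seq_wait, target=0)
    print("sequence 0 acknowledged")
```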
@@ -198,6 +207,7 @@ def __init__(
from_run_id=from_run_id,
from_step=from_step,
)
self.wait_for_processing(verbose=False)

@property
def resources(self) -> tuple[Resource, ...]:
@@ -208,10 +218,9 @@ def resources(self) -> tuple[Resource, ...]:
)

def _close(self) -> None:
# TODO: Change to wait for all operations to be processed
with self._lock:
if self._sync_process.is_alive():
self.wait_for_submission()
self.wait_for_processing()
self._sync_process.terminate()
self._sync_process.join()

@@ -320,49 +329,99 @@ def log(
for operation in splitter:
self._operations_queue.enqueue(operation=operation)

def wait_for_submission(self, timeout: Optional[float] = None) -> None:
"""
Waits until all metadata is submitted to Neptune.
"""
begin_time = time.time()
logger.info("Waiting for all operations to be processed")
if timeout is None:
def _wait(
self,
phrase: str,
sleep_time: float,
wait_condition: ConditionT,
external_value: Synchronized[int],
timeout: Optional[float] = None,
verbose: bool = True,
) -> None:
if verbose:
logger.info(f"Waiting for all operations to be {phrase}")

if timeout is None and verbose:
logger.warning("No timeout specified. Waiting indefinitely")

with self._lock:
if not self._sync_process.is_alive():
logger.warning("Sync process is not running")
if verbose:
logger.warning("Sync process is not running")
return # No need to wait if the sync process is not running

sleep_time_wait = (
min(MINIMAL_WAIT_FOR_PUT_SLEEP_TIME, timeout) if timeout is not None else MINIMAL_WAIT_FOR_PUT_SLEEP_TIME
)
begin_time = time.time()
wait_time = min(sleep_time, timeout) if timeout is not None else sleep_time
last_queued_sequence_id = self._operations_queue.last_sequence_id
last_message_printed: Optional[float] = None

while True:
with self._last_put_seq_wait:
self._last_put_seq_wait.wait(timeout=sleep_time_wait)
value = self._last_put_seq.value
with wait_condition:
wait_condition.wait(timeout=wait_time)
value = external_value.value

if value == -1:
if self._operations_queue.last_sequence_id != -1:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
if verbose and should_print_message(last_message_printed):
last_message_printed = time.time()
logger.info(
"Waiting. No operations processed yet. Operations to sync: %s",
f"Waiting. No operations were {phrase} yet. Operations to sync: %s",
self._operations_queue.last_sequence_id + 1,
)
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
if verbose and should_print_message(last_message_printed):
last_message_printed = time.time()
logger.info("Waiting. No operations processed yet")
logger.info(f"Waiting. No operations were {phrase} yet")
else:
if last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY:
if verbose and should_print_message(last_message_printed):
Review comment (kgodlewski, Contributor, Aug 22, 2024):

Tiny detail, but in the future we could do:

def print_message(msg: str, last_print_timestamp: Optional[float], verbose: bool) -> float

The function would return the current time if it printed the message, so we could just do:

last_message_printed = print_message(f"some message", last_message_printed, verbose)

Reply (Contributor Author): Done
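For illustration only, a runnable sketch of the helper suggested above; the constant value and the Optional return type are assumptions, and this is not what the PR ultimately implements (see should_print_message further down):

```python
import time
from typing import Optional

STOP_MESSAGE_FREQUENCY = 5.0  # assumed value; the real constant lives in neptune_scale.parameters


def print_message(msg: str, last_print_timestamp: Optional[float], verbose: bool) -> Optional[float]:
    """Print the message if verbose and enough time has passed since the last print.

    Returns the time of the print, or the previous timestamp unchanged if nothing was printed.
    """
    now = time.time()
    if verbose and (last_print_timestamp is None or now - last_print_timestamp > STOP_MESSAGE_FREQUENCY):
        print(msg)  # the library would call logger.info() here
        return now
    return last_print_timestamp


# Call site, as suggested in the comment:
last_message_printed = print_message("some message", None, verbose=True)
```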

last_message_printed = time.time()
logger.info(
"Waiting for remaining %d operation(s) to be synced",
f"Waiting for remaining %d operation(s) to be {phrase}",
last_queued_sequence_id - value + 1,
)

# Reaching the last queued sequence ID means that all operations were submitted
if value >= last_queued_sequence_id or (timeout is not None and time.time() - begin_time > timeout):
break

logger.info("All operations processed")
if verbose:
logger.info(f"All operations were {phrase}")

def wait_for_submission(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
"""
Waits until all metadata is submitted to Neptune.

Args:
timeout (float, optional): In seconds, the maximum time to wait for submission.
verbose (bool): If True (default), prints messages about the waiting process.
"""
self._wait(
phrase="submitted",
sleep_time=MINIMAL_WAIT_FOR_PUT_SLEEP_TIME,
wait_condition=self._last_put_seq_wait,
external_value=self._last_put_seq,
timeout=timeout,
verbose=verbose,
)

def wait_for_processing(self, timeout: Optional[float] = None, verbose: bool = True) -> None:
"""
Waits until all metadata is processed by Neptune.

Args:
timeout (float, optional): In seconds, the maximum time to wait for processing.
verbose (bool): If True (default), prints messages about the waiting process.
"""
self._wait(
phrase="processed",
sleep_time=MINIMAL_WAIT_FOR_ACK_SLEEP_TIME,
wait_condition=self._last_ack_seq_wait,
external_value=self._last_ack_seq,
timeout=timeout,
verbose=verbose,
)


def should_print_message(last_message_printed: Optional[float]) -> bool:
Review thread:

Contributor Author: @kgodlewski I've done the refactor you requested in the previous PR.

kgodlewski (Contributor): Is this user-facing? If so, what's the purpose?

Contributor Author: It is not 😉

"""Check if enough time has passed to print a message."""
return last_message_printed is None or time.time() - last_message_printed > STOP_MESSAGE_FREQUENCY
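Taken together, the new public surface is wait_for_submission() (operations handed off to Neptune) and wait_for_processing() (operations acknowledged by Neptune). A usage sketch under stated assumptions: the class name, constructor arguments, log() signature, and close() are assumptions about the public API, while the two wait_* signatures come from this diff.

```python
import os

from neptune_scale import Run  # assumed public export

run = Run(
    project="my-workspace/my-project",
    api_token=os.environ["NEPTUNE_API_TOKEN"],
    family="training-run-1",
    run_id="training-run-1",
)

run.log(step=1, metrics={"loss": 0.25})  # assumed log() signature

# Returns once every queued operation has been handed off to Neptune;
# with timeout=None it waits indefinitely (and logs a warning saying so).
run.wait_for_submission(timeout=30.0)

# Returns once Neptune has acknowledged (processed) every queued operation,
# the stronger guarantee added in this PR.
run.wait_for_processing(timeout=60.0, verbose=False)

run.close()  # assumed public close(); _close() above calls wait_for_processing() before terminating
```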
50 changes: 46 additions & 4 deletions src/neptune_scale/api/api_client.py
@@ -15,30 +15,43 @@
#
from __future__ import annotations

__all__ = ("HostedApiClient", "MockedApiClient", "ApiClient")
__all__ = ("HostedApiClient", "MockedApiClient", "ApiClient", "backend_factory")

import abc
import os
import uuid
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any
from typing import (
Any,
Literal,
)

from httpx import Timeout
from neptune_api import (
AuthenticatedClient,
Client,
)
from neptune_api.api.backend import get_client_config
from neptune_api.api.data_ingestion import submit_operation
from neptune_api.api.data_ingestion import (
check_request_status_bulk,
submit_operation,
)
from neptune_api.auth_helpers import exchange_api_key
from neptune_api.credentials import Credentials
from neptune_api.models import (
ClientConfig,
Error,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import RequestId
from neptune_api.proto.google_rpc.code_pb2 import Code
from neptune_api.proto.neptune_pb.ingest.v1.ingest_pb2 import IngestCode
from neptune_api.proto.neptune_pb.ingest.v1.pub.client_pb2 import (
BulkRequestStatus,
RequestId,
RequestIdList,
)
from neptune_api.proto.neptune_pb.ingest.v1.pub.ingest_pb2 import RunOperation
from neptune_api.proto.neptune_pb.ingest.v1.pub.request_status_pb2 import RequestStatus
from neptune_api.types import Response

from neptune_scale.core.components.abstract import Resource
@@ -95,6 +108,9 @@ class ApiClient(Resource, abc.ABC):
@abc.abstractmethod
def submit(self, operation: RunOperation, family: str) -> Response[RequestId]: ...

@abc.abstractmethod
def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]: ...


class HostedApiClient(ApiClient):
def __init__(self, api_token: str) -> None:
@@ -112,6 +128,13 @@ def __init__(self, api_token: str) -> None:
def submit(self, operation: RunOperation, family: str) -> Response[RequestId]:
return submit_operation.sync_detailed(client=self._backend, body=operation, family=family)

def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]:
return check_request_status_bulk.sync_detailed(
client=self._backend,
project_identifier=project,
body=RequestIdList(ids=[RequestId(value=request_id) for request_id in request_ids]),
)

def close(self) -> None:
logger.debug("Closing API client")
self._backend.__exit__()
@@ -123,3 +146,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

def submit(self, operation: RunOperation, family: str) -> Response[RequestId]:
return Response(content=b"", parsed=RequestId(value=str(uuid.uuid4())), status_code=HTTPStatus.OK, headers={})

def check_batch(self, request_ids: list[str], project: str) -> Response[BulkRequestStatus]:
response_body = BulkRequestStatus(
statuses=list(
map(
lambda _: RequestStatus(
code_by_count=[RequestStatus.CodeByCount(count=1, code=Code.OK, detail=IngestCode.OK)]
),
request_ids,
)
)
)
return Response(content=b"", parsed=response_body, status_code=HTTPStatus.OK, headers={})


def backend_factory(api_token: str, mode: Literal["async", "disabled"]) -> ApiClient:
if mode == "disabled":
return MockedApiClient()
return HostedApiClient(api_token=api_token)
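The new check_batch() endpoint pairs with submit(): the caller keeps the RequestId values returned by submit() and later polls their bulk status. A sketch using the MockedApiClient via backend_factory so it runs without credentials; the request IDs are placeholders and the protobuf field access mirrors the structures used in this file.

```python
from neptune_scale.api.api_client import backend_factory

# mode="disabled" returns MockedApiClient, so no API token or network access is needed.
client = backend_factory(api_token="", mode="disabled")

# IDs previously returned by submit(); placeholder values here.
request_ids = [
    "11111111-1111-1111-1111-111111111111",
    "22222222-2222-2222-2222-222222222222",
]

response = client.check_batch(request_ids, project="my-workspace/my-project")
statuses = response.parsed.statuses  # one RequestStatus per submitted request ID

for request_id, status in zip(request_ids, statuses):
    # Each RequestStatus aggregates (code, count) pairs; code 0 (google.rpc OK) means ingested.
    print(request_id, [(entry.code, entry.count) for entry in status.code_by_count])
```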
9 changes: 9 additions & 0 deletions src/neptune_scale/core/components/errors_tracking.py
@@ -17,6 +17,7 @@
NeptuneConnectionLostError,
NeptuneOperationsQueueMaxSizeExceeded,
NeptuneScaleError,
NeptuneScaleWarning,
NeptuneUnexpectedError,
)
from neptune_scale.parameters import ERRORS_MONITOR_THREAD_SLEEP_TIME
@@ -51,13 +52,18 @@ def default_max_queue_size_exceeded_callback(error: BaseException) -> None:
logger.warning(error)


def default_warning_callback(error: BaseException) -> None:
logger.warning(error)


class ErrorsMonitor(Daemon, Resource):
def __init__(
self,
errors_queue: ErrorsQueue,
max_queue_size_exceeded_callback: Optional[Callable[[BaseException], None]] = None,
on_network_error_callback: Optional[Callable[[BaseException], None]] = None,
on_error_callback: Optional[Callable[[BaseException], None]] = None,
on_warning_callback: Optional[Callable[[BaseException], None]] = None,
):
super().__init__(name="ErrorsMonitor", sleep_time=ERRORS_MONITOR_THREAD_SLEEP_TIME)

@@ -69,6 +75,7 @@ def __init__(
on_network_error_callback or default_network_error_callback
)
self._on_error_callback: Callable[[BaseException], None] = on_error_callback or default_error_callback
self._on_warning_callback: Callable[[BaseException], None] = on_warning_callback or default_warning_callback

def get_next(self) -> Optional[BaseException]:
try:
@@ -82,6 +89,8 @@ def work(self) -> None:
self._max_queue_size_exceeded_callback(error)
elif isinstance(error, NeptuneConnectionLostError):
self._non_network_error_callback(error)
elif isinstance(error, NeptuneScaleWarning):
self._on_warning_callback(error)
elif isinstance(error, NeptuneScaleError):
self._on_error_callback(error)
else:
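The new on_warning_callback routes NeptuneScaleWarning instances separately from hard errors. A hypothetical wiring sketch: the ErrorsQueue location and constructor are assumptions, while the ErrorsMonitor keyword arguments come from this diff.

```python
from neptune_scale.core.components.errors_tracking import ErrorsMonitor, ErrorsQueue  # ErrorsQueue path assumed


def on_warning(error: BaseException) -> None:
    # Application-level handling of non-fatal issues (NeptuneScaleWarning instances).
    print(f"neptune-scale warning: {error}")


def on_error(error: BaseException) -> None:
    # Fatal errors (NeptuneScaleError subclasses) still land here.
    raise SystemExit(f"neptune-scale error: {error}")


monitor = ErrorsMonitor(
    errors_queue=ErrorsQueue(),
    on_warning_callback=on_warning,
    on_error_callback=on_error,
)
# work() dispatches queued exceptions: NeptuneScaleWarning goes to on_warning,
# NeptuneScaleError to on_error, and anything else falls through to the final
# else branch (truncated in this diff).
```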